Skip to content

Commit ba52d14

Browse files
authored
Merge pull request #126 from yo-main/master
perf: refacto & slight performance improvements
2 parents 2e0a802 + ff7f288 commit ba52d14

File tree

8 files changed

+49
-126
lines changed

8 files changed

+49
-126
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,7 @@ venv.bak/
103103
# mypy
104104
.mypy_cache/
105105
_trial_temp
106-
.idea
106+
.idea
107+
108+
# vscode
109+
.vscode

casbin/enforcer.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,7 @@ def get_users_for_role(self, name):
2626
def has_role_for_user(self, name, role):
2727
""" determines whether a user has a role. """
2828
roles = self.get_roles_for_user(name)
29-
30-
hasRole = False
31-
for r in roles:
32-
if r == role:
33-
hasRole = True
34-
break
35-
36-
return hasRole
29+
return any(r == role for r in roles)
3730

3831
def add_role_for_user(self, user, role):
3932
"""
@@ -116,7 +109,7 @@ def has_permission_for_user(self, user, *permission):
116109
"""
117110
return self.has_policy(join_slice(user, *permission))
118111

119-
def get_implicit_roles_for_user(self, name, *domain):
112+
def get_implicit_roles_for_user(self, name, domain=None):
120113
"""
121114
gets implicit roles that a user has.
122115
Compared to get_roles_for_user(), this function retrieves indirect roles besides direct roles.
@@ -127,28 +120,22 @@ def get_implicit_roles_for_user(self, name, *domain):
127120
get_roles_for_user("alice") can only get: ["role:admin"].
128121
But get_implicit_roles_for_user("alice") will get: ["role:admin", "role:user"].
129122
"""
130-
res = list()
131-
roleSet = dict()
132-
roleSet[name] = True
133-
134-
q = list()
135-
q.append(name)
123+
res = []
124+
queue = [name]
136125

137-
while len(q) > 0:
138-
name = q[0]
139-
q = q[1:]
126+
while queue:
127+
name = queue.pop(0)
140128

141129
for rm in self.rm_map.values():
142-
roles = rm.get_roles(name, *domain)
130+
roles = rm.get_roles(name, domain)
143131
for r in roles:
144-
if r not in roleSet:
132+
if r not in res:
145133
res.append(r)
146-
q.append(r)
147-
roleSet[r] = True
134+
queue.append(r)
148135

149136
return res
150137

151-
def get_implicit_permissions_for_user(self, user, *domain):
138+
def get_implicit_permissions_for_user(self, user, domain=None):
152139
"""
153140
gets implicit permissions for a user or role.
154141
Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles.
@@ -160,23 +147,17 @@ def get_implicit_permissions_for_user(self, user, *domain):
160147
get_permissions_for_user("alice") can only get: [["alice", "data2", "read"]].
161148
But get_implicit_permissions_for_user("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]].
162149
"""
163-
roles = self.get_implicit_roles_for_user(user, *domain)
150+
roles = self.get_implicit_roles_for_user(user, domain)
164151

165152
roles.insert(0, user)
166153

167-
withDomain = False
168-
if len(domain) == 1:
169-
withDomain = True
170-
elif len(domain) > 1:
171-
return None
172-
173154
res = []
174-
permissions = [list()]
175155
for role in roles:
176-
if withDomain:
177-
permissions = self.get_permissions_for_user_in_domain(role, domain[0])
156+
if domain:
157+
permissions = self.get_permissions_for_user_in_domain(role, domain)
178158
else:
179159
permissions = self.get_permissions_for_user(role)
160+
180161
res.extend(permissions)
181162

182163
return res
@@ -227,4 +208,4 @@ def delete_roles_for_user_in_domain(self, user, role, domain):
227208

228209
def get_permissions_for_user_in_domain(self, user, domain):
229210
"""gets permissions for a user or role inside domain."""
230-
return self.get_filtered_policy(0, user, domain)
211+
return self.get_filtered_policy(0, user, domain)

casbin/management_enforcer.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,7 @@ def has_named_policy(self, ptype, *params):
7979
str_slice = params[0]
8080
return self.model.has_policy('p', ptype, str_slice)
8181

82-
policy = []
83-
84-
for param in params:
85-
policy.append(param)
86-
87-
return self.model.has_policy('p', ptype, policy)
82+
return self.model.has_policy('p', ptype, list(params))
8883

8984
def add_policy(self, *params):
9085
"""adds an authorization rule to the current policy.
@@ -113,12 +108,7 @@ def add_named_policy(self, ptype, *params):
113108
str_slice = params[0]
114109
rule_added = self._add_policy('p', ptype, str_slice)
115110
else:
116-
policy = []
117-
118-
for param in params:
119-
policy.append(param)
120-
121-
rule_added = self._add_policy('p', ptype, policy)
111+
rule_added = self._add_policy('p', ptype, list(params))
122112

123113
return rule_added
124114

@@ -164,12 +154,7 @@ def remove_named_policy(self, ptype, *params):
164154
str_slice = params[0]
165155
rule_removed = self._remove_policy('p', ptype, str_slice)
166156
else:
167-
policy = []
168-
169-
for param in params:
170-
policy.append(param)
171-
172-
rule_removed = self._remove_policy('p', ptype, policy)
157+
rule_removed = self._remove_policy('p', ptype, list(params))
173158

174159
return rule_removed
175160

@@ -193,12 +178,7 @@ def has_named_grouping_policy(self, ptype, *params):
193178
str_slice = params[0]
194179
return self.model.has_policy('g', ptype, str_slice)
195180

196-
policy = []
197-
198-
for param in params:
199-
policy.append(param)
200-
201-
return self.model.has_policy('g', ptype, policy)
181+
return self.model.has_policy('g', ptype, list(params))
202182

203183
def add_grouping_policy(self, *params):
204184
"""adds a role inheritance rule to the current policy.
@@ -227,12 +207,7 @@ def add_named_grouping_policy(self, ptype, *params):
227207
str_slice = params[0]
228208
rule_added = self._add_policy('g', ptype, str_slice)
229209
else:
230-
policy = []
231-
232-
for param in params:
233-
policy.append(param)
234-
235-
rule_added = self._add_policy('g', ptype, policy)
210+
rule_added = self._add_policy('g', ptype, list(params))
236211

237212
if self.auto_build_role_links:
238213
self.build_role_links()
@@ -268,12 +243,7 @@ def remove_named_grouping_policy(self, ptype, *params):
268243
str_slice = params[0]
269244
rule_removed = self._remove_policy('g', ptype, str_slice)
270245
else:
271-
policy = []
272-
273-
for param in params:
274-
policy.append(param)
275-
276-
rule_removed = self._remove_policy('g', ptype, policy)
246+
rule_removed = self._remove_policy('g', ptype, list(params))
277247

278248
if self.auto_build_role_links:
279249
self.build_role_links()

casbin/model/assertion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ def __init__(self):
1313
def build_role_links(self, rm):
1414
self.rm = rm
1515
count = self.value.count("_")
16+
if count < 2:
17+
raise RuntimeError('the number of "_" in role definition should be at least 2')
1618

1719
for rule in self.policy:
18-
if count < 2:
19-
raise RuntimeError('the number of "_" in role definition should be at least 2')
20-
2120
if len(rule) < count:
2221
raise RuntimeError("grouping policy elements do not meet role definition")
2322

casbin/model/policy.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import logging
22

3-
from casbin import util
4-
5-
63
class Policy:
74
def __init__(self):
85
self.logger = logging.getLogger()
@@ -36,7 +33,7 @@ def clear_policy(self):
3633
if sec not in self.model.keys():
3734
continue
3835

39-
for key, ast in self.model[sec].items():
36+
for key in self.model[sec].keys():
4037
self.model[sec][key].policy = []
4138

4239
def get_policy(self, sec, ptype):
@@ -46,19 +43,10 @@ def get_policy(self, sec, ptype):
4643

4744
def get_filtered_policy(self, sec, ptype, field_index, *field_values):
4845
"""gets rules based on field filters from a policy."""
49-
res = []
50-
51-
for rule in self.model[sec][ptype].policy:
52-
matched = True
53-
for i, field_value in enumerate(field_values):
54-
if field_value != '' and rule[field_index + i] != field_value:
55-
matched = False
56-
break
57-
58-
if matched:
59-
res.append(rule)
60-
61-
return res
46+
return [
47+
rule for rule in self.model[sec][ptype].policy
48+
if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values))
49+
]
6250

6351
def has_policy(self, sec, ptype, rule):
6452
"""determines whether a model has the specified policy rule."""
@@ -80,14 +68,14 @@ def add_policy(self, sec, ptype, rule):
8068

8169
def add_policies(self,sec,ptype,rules):
8270
"""adds policy rules to the model."""
83-
71+
8472
for rule in rules:
8573
if self.has_policy(sec,ptype,rule):
8674
return False
8775

8876
for rule in rules:
8977
self.model[sec][ptype].policy.append(rule)
90-
78+
9179
return True
9280

9381
def update_policy(self, sec, ptype, old_rule, new_rule):
@@ -109,7 +97,6 @@ def update_policies(self, sec, ptype, old_rules, new_rules):
10997

11098
def remove_policy(self, sec, ptype, rule):
11199
"""removes a policy rule from the model."""
112-
113100
if not self.has_policy(sec, ptype, rule):
114101
return False
115102

@@ -126,7 +113,7 @@ def remove_policies(self, sec, ptype, rules):
126113
self.model[sec][ptype].policy.remove(rule)
127114
if rule in self.model[sec][ptype].policy:
128115
return False
129-
116+
130117
return True
131118

132119
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
@@ -140,13 +127,7 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
140127
return res
141128

142129
for rule in self.model[sec][ptype].policy:
143-
matched = True
144-
for i, field_value in enumerate(field_values):
145-
if field_value != '' and rule[field_index + i] != field_value:
146-
matched = False
147-
break
148-
149-
if matched:
130+
if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values)):
150131
res = True
151132
else:
152133
tmp.append(rule)
@@ -157,14 +138,15 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
157138

158139
def get_values_for_field_in_policy(self, sec, ptype, field_index):
159140
"""gets all values for a field for all rules in a policy, duplicated values are removed."""
160-
161141
values = []
162142
if sec not in self.model.keys():
163143
return values
164144
if ptype not in self.model[sec]:
165145
return values
166146

167147
for rule in self.model[sec][ptype].policy:
168-
values.append(rule[field_index])
148+
value = rule[field_index]
149+
if value not in values:
150+
values.append(value)
169151

170-
return util.array_remove_duplicates(values)
152+
return values

casbin/rbac/default_role_manager/role_manager.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,21 @@ def has_link(self, name1, name2, *domain):
103103
return True
104104
return False
105105

106-
def get_roles(self, name, *domain):
106+
def get_roles(self, name, domain=None):
107107
"""
108108
gets the roles that a subject inherits.
109109
domain is a prefix to the roles.
110110
"""
111-
if len(domain) == 1:
112-
name = domain[0] + "::" + name
113-
elif len(domain) > 1:
114-
return RuntimeError("error: domain should be 1 parameter")
111+
if domain:
112+
name = domain + "::" + name
115113

116114
if not self.has_role(name):
117115
return []
118116

119117
roles = self.create_role(name).get_roles()
120-
if len(domain) == 1:
118+
if domain:
121119
for key, value in enumerate(roles):
122-
roles[key] = value[len(domain[0]) + 2:]
120+
roles[key] = value[len(domain) + 2:]
123121

124122
return roles
125123

casbin/util/util.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,7 @@ def join_slice(a, *b):
4848

4949
def set_subtract(a, b):
5050
''' returns the elements in `a` that aren't in `b`. '''
51-
mb = dict()
52-
53-
for x in b:
54-
mb[x] = True
55-
56-
diff = list()
57-
for x in a:
58-
if x not in mb:
59-
diff.append(x)
60-
61-
return diff
51+
return [i for i in a if i not in b]
6252

6353
def has_eval(s):
6454
'''determine whether matcher contains function eval'''

test_filter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_load_filtered_policy(self):
3434

3535
self.assertTrue(e.has_policy(['admin', 'domain1', 'data1','read']))
3636
self.assertFalse(e.has_policy(['admin', 'domain2', 'data2','read']))
37-
37+
3838
with self.assertRaises(RuntimeError):
3939
e.save_policy()
4040

@@ -121,5 +121,5 @@ def test_filtered_adapter_invalid_filepath(self):
121121
adapter = casbin.persist.adapters.FilteredAdapter("examples/does_not_exist_policy.csv")
122122
e = casbin.Enforcer("examples/rbac_with_domains_model.conf", adapter)
123123

124-
with self.assertRaises(FileNotFoundError):
125-
e.load_filtered_policy(None)
124+
with self.assertRaises(RuntimeError):
125+
e.load_filtered_policy(None)

0 commit comments

Comments
 (0)