Skip to content

Commit 0dc281c

Browse files
authored
feat: Publication rule service should use roles of API key #1341 (#1342)
1 parent 1b548ec commit 0dc281c

File tree

3 files changed

+122
-60
lines changed

3 files changed

+122
-60
lines changed

server/src/main/java/com/epam/aidial/core/server/security/RuleMatcher.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.util.Collection;
88
import java.util.List;
99
import java.util.Map;
10+
import java.util.Optional;
1011
import java.util.regex.Pattern;
1112

1213
@UtilityClass
@@ -22,20 +23,16 @@ public boolean match(ProxyContext context, Collection<Rule> rules) {
2223
return true;
2324
}
2425

25-
ExtractedClaims claims = context.getExtractedClaims();
26-
if (claims == null) {
27-
return false;
28-
}
29-
30-
Map<String, List<String>> userClaims = claims.userClaims();
26+
Map<String, List<String>> userClaims = Optional.ofNullable(context.getExtractedClaims())
27+
.map(ExtractedClaims::userClaims).orElse(null);
3128

3229
for (Rule rule : rules) {
3330
String targetClaim = rule.getSource();
3431
List<String> sources;
3532
if (targetClaim.equals("roles")) {
36-
sources = claims.userRoles();
33+
sources = context.getUserRoles();
3734
} else {
38-
sources = userClaims.get(targetClaim);
35+
sources = userClaims == null ? null : userClaims.get(targetClaim);
3936
}
4037

4138
if (sources == null) {

server/src/test/java/com/epam/aidial/core/server/security/AccessServiceTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ public void testGetOwnResourcesAccessForChainedSchemaRichApplication_NotAnApplic
311311
public void testGetAdminAccess_WhenPublicResource() {
312312
ResourceDescriptor resource = new ResourceDescriptor(ResourceTypes.FILE, "file.json", List.of(), "bucket", "public/", false);
313313
ApplicationSchemaService applicationSchemaService = mock(ApplicationSchemaService.class);
314-
ExtractedClaims extractedClaims = new ExtractedClaims("sub", List.of("admin"), "hash", Map.of(), null, "userName");
315-
when(context.getExtractedClaims()).thenReturn(extractedClaims);
314+
when(context.getUserRoles()).thenReturn(List.of("admin"));
316315
when(context.getApiKeyData()).thenReturn(new ApiKeyData());
317316
AccessService accessService = new AccessService(encryptionService, shareService, ruleService,
318317
applicationSchemaService,
@@ -335,8 +334,7 @@ public void testGetAdminAccess_WhenReviewResource() {
335334
String reviewLocation = "/Users/sub/publications/123/";
336335
ResourceDescriptor resource = new ResourceDescriptor(ResourceTypes.FILE, "file.json", List.of(), "bucket", reviewLocation, false);
337336
ApplicationSchemaService applicationSchemaService = mock(ApplicationSchemaService.class);
338-
ExtractedClaims extractedClaims = new ExtractedClaims("sub", List.of("admin"), "hash", Map.of(), null, "userName");
339-
when(context.getExtractedClaims()).thenReturn(extractedClaims);
337+
when(context.getUserRoles()).thenReturn(List.of("admin"));
340338
when(context.getApiKeyData()).thenReturn(new ApiKeyData());
341339
AccessService accessService = new AccessService(encryptionService, shareService, ruleService,
342340
applicationSchemaService,
@@ -380,8 +378,7 @@ public void testGetAdminAccess_WhenPrivateResource() {
380378
public void testGetAdminAccess_WhenSourceFolderOfCodeApp() {
381379
ResourceDescriptor resource = new ResourceDescriptor(ResourceTypes.FILE, "app.py", List.of(), "bucket", "public/deployments/123/", false);
382380
ApplicationSchemaService applicationSchemaService = mock(ApplicationSchemaService.class);
383-
ExtractedClaims extractedClaims = new ExtractedClaims("sub", List.of("admin"), "hash", Map.of(), null, "userName");
384-
when(context.getExtractedClaims()).thenReturn(extractedClaims);
381+
when(context.getUserRoles()).thenReturn(List.of("admin"));
385382
when(context.getApiKeyData()).thenReturn(new ApiKeyData());
386383
AccessService accessService = new AccessService(encryptionService, shareService, ruleService,
387384
applicationSchemaService,

server/src/test/java/com/epam/aidial/core/server/security/RuleMatcherTest.java

Lines changed: 114 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,101 +13,169 @@ class RuleMatcherTest {
1313

1414
@Test
1515
void testUserRoleRules() {
16-
verify(rule("roles", Rule.Function.TRUE), true, "any-role");
17-
verify(rule("roles", Rule.Function.FALSE), false, "any-role");
18-
19-
verify(rule("roles", Rule.Function.EQUAL, "admin"), true, "admin");
20-
verify(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin");
21-
verify(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin2");
22-
verify(rule("roles", Rule.Function.EQUAL, "AdMin"), true, "admin");
23-
24-
verify(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin");
25-
verify(rule("roles", Rule.Function.CONTAIN, "admin1"), false, "admin");
26-
verify(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin2");
27-
verify(rule("roles", Rule.Function.CONTAIN, "dmi"), true, "admin");
28-
verify(rule("roles", Rule.Function.CONTAIN, "manager"), true, "Delivery Manager");
29-
30-
verify(rule("roles", Rule.Function.REGEX, ".*"), true, "any");
31-
verify(rule("roles", Rule.Function.REGEX, "(admin|user)"), true, "user");
32-
verify(rule("roles", Rule.Function.REGEX, "(admin|user)$"), false, "user2");
16+
17+
// test access by access token
18+
19+
verifyOnAccessToken(rule("roles", Rule.Function.TRUE), true, "any-role");
20+
verifyOnAccessToken(rule("roles", Rule.Function.FALSE), false, "any-role");
21+
22+
verifyOnAccessToken(rule("roles", Rule.Function.EQUAL, "admin"), true, "admin");
23+
verifyOnAccessToken(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin");
24+
verifyOnAccessToken(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin2");
25+
verifyOnAccessToken(rule("roles", Rule.Function.EQUAL, "AdMin"), true, "admin");
26+
27+
verifyOnAccessToken(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin");
28+
verifyOnAccessToken(rule("roles", Rule.Function.CONTAIN, "admin1"), false, "admin");
29+
verifyOnAccessToken(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin2");
30+
verifyOnAccessToken(rule("roles", Rule.Function.CONTAIN, "dmi"), true, "admin");
31+
verifyOnAccessToken(rule("roles", Rule.Function.CONTAIN, "manager"), true, "Delivery Manager");
32+
33+
verifyOnAccessToken(rule("roles", Rule.Function.REGEX, ".*"), true, "any");
34+
verifyOnAccessToken(rule("roles", Rule.Function.REGEX, "(admin|user)"), true, "user");
35+
verifyOnAccessToken(rule("roles", Rule.Function.REGEX, "(admin|user)$"), false, "user2");
36+
37+
// test access by API key
38+
39+
verifyOnApiKey(rule("roles", Rule.Function.TRUE), true, "any-role");
40+
verifyOnApiKey(rule("roles", Rule.Function.FALSE), false, "any-role");
41+
42+
verifyOnApiKey(rule("roles", Rule.Function.EQUAL, "admin"), true, "admin");
43+
verifyOnApiKey(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin");
44+
verifyOnApiKey(rule("roles", Rule.Function.EQUAL, "admin1"), false, "admin2");
45+
verifyOnApiKey(rule("roles", Rule.Function.EQUAL, "AdMin"), true, "admin");
46+
47+
verifyOnApiKey(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin");
48+
verifyOnApiKey(rule("roles", Rule.Function.CONTAIN, "admin1"), false, "admin");
49+
verifyOnApiKey(rule("roles", Rule.Function.CONTAIN, "admin"), true, "admin2");
50+
verifyOnApiKey(rule("roles", Rule.Function.CONTAIN, "dmi"), true, "admin");
51+
verifyOnApiKey(rule("roles", Rule.Function.CONTAIN, "manager"), true, "Delivery Manager");
52+
53+
verifyOnApiKey(rule("roles", Rule.Function.REGEX, ".*"), true, "any");
54+
verifyOnApiKey(rule("roles", Rule.Function.REGEX, "(admin|user)"), true, "user");
55+
verifyOnApiKey(rule("roles", Rule.Function.REGEX, "(admin|user)$"), false, "user2");
3356
}
3457

3558
@Test
3659
void testUserClaimRules() {
37-
verify(List.of(rule("title", Rule.Function.TRUE)),
60+
// test access by access token
61+
62+
verifyOnAccessToken(List.of(rule("title", Rule.Function.TRUE)),
3863
List.of("user1"), Map.of("title", List.of("Software Engineer")), true);
39-
verify(List.of(rule("title", Rule.Function.FALSE)),
64+
verifyOnAccessToken(List.of(rule("title", Rule.Function.FALSE)),
4065
List.of("user2"), Map.of("title", List.of("Software Engineer")), false);
4166

42-
verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer")),
67+
verifyOnAccessToken(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer")),
4368
List.of("admin"), Map.of("title", List.of("Software Engineer")), true);
44-
verify(List.of(rule("title", Rule.Function.EQUAL, "Engineer")),
69+
verifyOnAccessToken(List.of(rule("title", Rule.Function.EQUAL, "Engineer")),
4570
List.of("user"), Map.of("title", List.of("Software Engineer")), false);
4671

47-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
72+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
4873
List.of("admin"), Map.of("email", List.of("foo_bar@example.com")), true);
49-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
74+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
5075
List.of("user"), Map.of("email", List.of("foo_bar@mail.com")), false);
51-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
76+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
5277
List.of("user"), Map.of(), false);
5378

54-
verify(List.of(rule("title", Rule.Function.REGEX, ".*")),
79+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*")),
5580
List.of("admin"), Map.of("title", List.of("Developer")), true);
56-
verify(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)")),
81+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)")),
5782
List.of("user"), Map.of("title", List.of("Manager")), true);
58-
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
83+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
5984
List.of("user"), Map.of("title", List.of("Senior Delivery Manager")), true);
60-
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
85+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
6186
List.of("user"), Map.of("title", List.of("Manager Senior")), false);
87+
88+
// test access by API key
89+
90+
verifyOnApiKey(List.of(rule("title", Rule.Function.TRUE)),
91+
List.of("user1"), false);
92+
93+
verifyOnApiKey(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer")),
94+
List.of("admin"), false);
95+
96+
verifyOnApiKey(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
97+
List.of("admin"), false);
98+
99+
verifyOnApiKey(List.of(rule("title", Rule.Function.REGEX, ".*")),
100+
List.of("admin"), false);
62101
}
63102

64103
@Test
65104
void testCombinedRules() {
66-
verify(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
105+
// test access by access token
106+
107+
verifyOnAccessToken(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
67108
List.of("user1"), Map.of("title", List.of("Software Engineer")), true);
68-
verify(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
109+
verifyOnAccessToken(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
69110
List.of("dial"), Map.of(), true);
70-
verify(List.of(rule("title", Rule.Function.CONTAIN, "Software"), rule("roles", Rule.Function.EQUAL, "dial")),
111+
verifyOnAccessToken(List.of(rule("title", Rule.Function.CONTAIN, "Software"), rule("roles", Rule.Function.EQUAL, "dial")),
71112
List.of("custom"), Map.of("title", List.of("System Engineer")), false);
72113

73-
verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
114+
verifyOnAccessToken(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
74115
List.of("admin"), Map.of("title", List.of("Software Engineer")), true);
75-
verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
116+
verifyOnAccessToken(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
76117
List.of("dial"), Map.of("title", List.of("Manager")), true);
77-
verify(List.of(rule("title", Rule.Function.EQUAL, "Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
118+
verifyOnAccessToken(List.of(rule("title", Rule.Function.EQUAL, "Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
78119
List.of("user"), Map.of("title", List.of("Software Engineer")), false);
79120

80-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
121+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
81122
List.of("admin"), Map.of("email", List.of("foo_bar@example.com")), true);
82-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
123+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
83124
List.of("dial"), Map.of("email", List.of("foo_bar@example2.com")), true);
84-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
125+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
85126
List.of("user"), Map.of("email", List.of("foo_bar@mail.com")), false);
86-
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
127+
verifyOnAccessToken(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
87128
List.of("user"), Map.of(), false);
88129

89-
verify(List.of(rule("title", Rule.Function.REGEX, ".*"), rule("roles", Rule.Function.EQUAL, "dial")),
130+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*"), rule("roles", Rule.Function.EQUAL, "dial")),
90131
List.of("admin"), Map.of("title", List.of("Developer")), true);
91-
verify(List.of(rule("title", Rule.Function.REGEX, "^Developer$"), rule("roles", Rule.Function.EQUAL, "dial")),
132+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, "^Developer$"), rule("roles", Rule.Function.EQUAL, "dial")),
92133
List.of("dial"), Map.of("title", List.of("Manager")), true);
93-
verify(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)"), rule("roles", Rule.Function.EQUAL, "dial")),
134+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)"), rule("roles", Rule.Function.EQUAL, "dial")),
94135
List.of("dial"), Map.of("title", List.of("Human Resource")), true);
95-
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
136+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
96137
List.of("user"), Map.of("title", List.of("Senior Delivery Manager")), true);
97-
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
138+
verifyOnAccessToken(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
98139
List.of("user"), Map.of("title", List.of("Manager Senior")), false);
140+
141+
// test access by API key
142+
143+
verifyOnApiKey(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
144+
List.of("user1"), false);
145+
146+
147+
verifyOnApiKey(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
148+
List.of("admin"), false);
149+
150+
verifyOnApiKey(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
151+
List.of("admin"), false);
152+
153+
verifyOnApiKey(List.of(rule("title", Rule.Function.REGEX, ".*"), rule("roles", Rule.Function.EQUAL, "dial")),
154+
List.of("admin"), false);
99155
}
100156

101-
void verify(List<Rule> rules, List<String> userRoles, Map<String, List<String>> userClaims, boolean expected) {
157+
void verifyOnAccessToken(List<Rule> rules, List<String> userRoles, Map<String, List<String>> userClaims, boolean expected) {
102158
ProxyContext context = Mockito.mock(ProxyContext.class);
159+
Mockito.when(context.getUserRoles()).thenReturn(userRoles);
103160
ExtractedClaims claims = new ExtractedClaims("sub", userRoles, "hash", userClaims, null, null);
104161
Mockito.when(context.getExtractedClaims()).thenReturn(claims);
105162
boolean actual = RuleMatcher.match(context, rules);
106163
Assertions.assertEquals(expected, actual);
107164
}
108165

109-
void verify(Rule rule, boolean expected, String... roles) {
110-
verify(List.of(rule), List.of(roles), Map.of(), expected);
166+
void verifyOnAccessToken(Rule rule, boolean expected, String... roles) {
167+
verifyOnAccessToken(List.of(rule), List.of(roles), Map.of(), expected);
168+
}
169+
170+
void verifyOnApiKey(List<Rule> rules, List<String> userRoles, boolean expected) {
171+
ProxyContext context = Mockito.mock(ProxyContext.class);
172+
Mockito.when(context.getUserRoles()).thenReturn(userRoles);
173+
boolean actual = RuleMatcher.match(context, rules);
174+
Assertions.assertEquals(expected, actual);
175+
}
176+
177+
void verifyOnApiKey(Rule rule, boolean expected, String... roles) {
178+
verifyOnApiKey(List.of(rule), List.of(roles), expected);
111179
}
112180

113181
Rule rule(String source, Rule.Function function, String... targets) {

0 commit comments

Comments
 (0)