Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.epam.aidial.core.credentials.service.metadata.HttpHeadersHandler;
import com.epam.aidial.core.credentials.util.JsonMapperUtil;
import com.epam.aidial.core.storage.http.HttpException;
import com.epam.aidial.core.storage.http.HttpStatus;
import com.google.common.annotations.VisibleForTesting;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -91,6 +92,8 @@ private <R> R execute(HttpRequest request, Class<R> responseType) {
}
}

checkOauthError(body, request.uri());

return JsonMapperUtil.convertToObject(body, responseType);
} catch (ConnectException e) {
if (hasUnresolvedAddressException(e)) {
Expand All @@ -114,4 +117,23 @@ private static boolean hasUnresolvedAddressException(Throwable ex) {
private java.time.Duration createRequestConfig() {
return java.time.Duration.ofSeconds(30);
}

/**
* Some OAuth Authorization Servers return HTTP 200 with an error payload
* instead of a proper error status code. Detect and handle this case.
*/
private static void checkOauthError(String body, URI uri) {
if (body == null || body.isBlank()) {
return;
}
var node = JsonMapperUtil.convertToObject(body, java.util.Map.class);
if (node != null && node.containsKey("error")) {
String error = String.valueOf(node.get("error"));
String description = node.containsKey("error_description")
? String.valueOf(node.get("error_description"))
: "no description";
log.debug("OAuth error in 200 response from {}: error={}, description={}", uri, error, description);
throw new HttpException(HttpStatus.BAD_REQUEST, "Authorization server returned error: %s (%s)".formatted(error, description));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ public TokenResponse getToken(String resourceId,
.redirectUri(resourceAuthSettings.getRedirectUri())
.build();

TokenResponse tokenResponse = resourceAuthorizationClient.executePost(
resourceAuthSettings.getTokenEndpoint(),
tokenRequest.buildFormData(),
"application/x-www-form-urlencoded",
TokenResponse.class);
TokenResponse tokenResponse = doTokenCall(resourceAuthSettings.getTokenEndpoint(), tokenRequest.buildFormData());
log.debug("Finished Resource {} token retrieval", resourceId);
return tokenResponse;
}
Expand All @@ -48,12 +44,15 @@ public TokenResponse getToken(String resourceId,
.refreshToken(refreshToken)
.build();

TokenResponse tokenResponse = resourceAuthorizationClient.executePost(
resourceAuthSettings.getTokenEndpoint(),
tokenRequest.buildFormData(),
"application/x-www-form-urlencoded",
TokenResponse.class);
TokenResponse tokenResponse = doTokenCall(resourceAuthSettings.getTokenEndpoint(), tokenRequest.buildFormData());
log.debug("Finished Resource {} refresh token retrieval", resourceId);
return tokenResponse;
}

private TokenResponse doTokenCall(String tokenEndpoint, String tokenRequest) {
return resourceAuthorizationClient.executePost(
tokenEndpoint, tokenRequest,
"application/x-www-form-urlencoded",
TokenResponse.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,87 @@ void testExecutePost_ConnectException() throws Exception {
assertEquals("Cannot connect to https://example.com/resource", exception.getMessage());
}

@Test
void testExecuteGet_OauthErrorInSuccessResponse() throws Exception {
// Given
String url = "https://example.com/token";
String errorResponse = "{\"error\":\"invalid_grant\",\"error_description\":\"The authorization code has expired\"}";
HttpResponse<String> httpResponseMock = mock(HttpResponse.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);
when(httpResponseMock.statusCode()).thenReturn(200);
when(httpResponseMock.body()).thenReturn(errorResponse);

// When
HttpException exception = assertThrows(HttpException.class,
() -> resourceAuthorizationClient.executeGet(url, TestResponse.class));

// Then
assertEquals(400, exception.getStatus().getCode());
assertTrue(exception.getMessage().contains("invalid_grant"));
assertTrue(exception.getMessage().contains("The authorization code has expired"));
}

@Test
void testExecutePost_OauthErrorInSuccessResponse() throws Exception {
// Given
String url = "https://example.com/token";
TestRequest requestPayload = new TestRequest("testValue");
String errorResponse = "{\"error\":\"invalid_client\",\"error_description\":\"Invalid redirect_uri\"}";
HttpResponse<String> httpResponseMock = mock(HttpResponse.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);
when(httpResponseMock.statusCode()).thenReturn(200);
when(httpResponseMock.body()).thenReturn(errorResponse);

// When
HttpException exception = assertThrows(HttpException.class,
() -> resourceAuthorizationClient.executePost(url, requestPayload,
ContentType.APPLICATION_JSON.toString(), TestResponse.class));

// Then
assertEquals(400, exception.getStatus().getCode());
assertTrue(exception.getMessage().contains("invalid_client"));
assertTrue(exception.getMessage().contains("Invalid redirect_uri"));
}

@Test
void testExecutePost_OauthErrorWithoutDescription() throws Exception {
// Given
String url = "https://example.com/token";
TestRequest requestPayload = new TestRequest("testValue");
String errorResponse = "{\"error\":\"server_error\"}";
HttpResponse<String> httpResponseMock = mock(HttpResponse.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);
when(httpResponseMock.statusCode()).thenReturn(200);
when(httpResponseMock.body()).thenReturn(errorResponse);

// When
HttpException exception = assertThrows(HttpException.class,
() -> resourceAuthorizationClient.executePost(url, requestPayload,
ContentType.APPLICATION_JSON.toString(), TestResponse.class));

// Then
assertEquals(400, exception.getStatus().getCode());
assertTrue(exception.getMessage().contains("server_error"));
assertTrue(exception.getMessage().contains("no description"));
}

@Test
void testExecuteGet_ValidResponseNotTreatedAsOauthError() throws Exception {
String url = "https://example.com/resource";
String jsonResponse = "{\"key\":\"value\"}";
HttpResponse<String> httpResponseMock = mock(HttpResponse.class);
when(httpClientMock.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(httpResponseMock);
when(httpResponseMock.statusCode()).thenReturn(200);
when(httpResponseMock.body()).thenReturn(jsonResponse);

// When
TestResponse actualResponse = resourceAuthorizationClient.executeGet(url, TestResponse.class);

// Then
assertNotNull(actualResponse);
assertEquals("value", actualResponse.getKey());
}

@Data
@AllArgsConstructor
@NoArgsConstructor
Expand Down
Loading