Skip to content
Open
18 changes: 13 additions & 5 deletions core/src/main/kotlin/com/avsystem/justworks/core/ArrowHelpers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,27 @@ import kotlin.contracts.InvocationKind.AT_MOST_ONCE
import kotlin.contracts.contract

@OptIn(ExperimentalContracts::class)
context(warnings: IorRaise<Nel<Error>>)
context(iorRaise: IorRaise<Nel<Error>>)
inline fun <Error> ensureOrAccumulate(condition: Boolean, error: () -> Error) {
contract { callsInPlace(error, AT_MOST_ONCE) }
if (!condition) {
warnings.accumulate(nonEmptyListOf(error()))
iorRaise.accumulate(nonEmptyListOf(error()))
}
}

@OptIn(ExperimentalContracts::class)
context(warnings: IorRaise<Nel<Error>>)
inline fun <Error, B : Any> ensureNotNullOrAccumulate(value: B?, error: () -> Error) {
context(iorRaise: IorRaise<Nel<Error>>)
inline fun <Error, B : Any> ensureNotNullOrAccumulate(value: B?, error: () -> Error): B? {
contract { callsInPlace(error, AT_MOST_ONCE) }
if (value == null) {
warnings.accumulate(nonEmptyListOf(error()))
iorRaise.accumulate(nonEmptyListOf(error()))
}
return value
}

/** Accumulates a single error and returns `null`, for use in `when` branches that must yield a nullable result. */
context(iorRaise: IorRaise<Nel<Error>>)
fun <Error> accumulate(error: Error): Nothing? {
iorRaise.accumulate(nonEmptyListOf(error))
return null
}
3 changes: 2 additions & 1 deletion core/src/main/kotlin/com/avsystem/justworks/core/Issue.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
@file:OptIn(ExperimentalRaiseAccumulateApi::class)

package com.avsystem.justworks.core

import arrow.core.Nel
import arrow.core.raise.ExperimentalRaiseAccumulateApi
import arrow.core.raise.IorRaise
import kotlin.contracts.ExperimentalContracts

object Issue {
data class Error(val message: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.avsystem.justworks.core.gen.model.ModelGenerator
import com.avsystem.justworks.core.gen.shared.ApiClientBaseGenerator
import com.avsystem.justworks.core.gen.shared.ApiResponseGenerator
import com.avsystem.justworks.core.model.ApiSpec
import com.avsystem.justworks.core.model.SecurityScheme
import java.io.File

/**
Expand Down Expand Up @@ -36,8 +37,8 @@ object CodeGenerator {
return Result(modelFiles.size, clientFiles.size)
}

fun generateSharedTypes(outputDir: File): Int {
val files = ApiResponseGenerator.generate() + ApiClientBaseGenerator.generate()
fun generateSharedTypes(outputDir: File, securitySchemes: List<SecurityScheme>): Int {
val files = ApiResponseGenerator.generate() + ApiClientBaseGenerator.generate(securitySchemes)
files.forEach { it.writeTo(outputDir) }
return files.size
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ val HTTP_SUCCESS = ClassName("com.avsystem.justworks", "HttpSuccess")
// Kotlin stdlib
// ============================================================================

val BASE64_CLASS = ClassName("java.util", "Base64")
val CLOSEABLE = ClassName("java.io", "Closeable")
val IO_EXCEPTION = ClassName("java.io", "IOException")
val HTTP_REQUEST_TIMEOUT_EXCEPTION = ClassName("io.ktor.client.plugins", "HttpRequestTimeoutException")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ import com.avsystem.justworks.core.gen.HTTP_SUCCESS
import com.avsystem.justworks.core.gen.ModelPackage
import com.avsystem.justworks.core.gen.NameRegistry
import com.avsystem.justworks.core.gen.RAISE
import com.avsystem.justworks.core.gen.TOKEN
import com.avsystem.justworks.core.gen.client.BodyGenerator.buildFunctionBody
import com.avsystem.justworks.core.gen.client.ParametersGenerator.buildBodyParams
import com.avsystem.justworks.core.gen.client.ParametersGenerator.buildNullableParameter
import com.avsystem.justworks.core.gen.invoke
import com.avsystem.justworks.core.gen.sanitizeKdoc
import com.avsystem.justworks.core.gen.shared.ApiClientBaseGenerator
import com.avsystem.justworks.core.gen.toCamelCase
import com.avsystem.justworks.core.gen.toPascalCase
import com.avsystem.justworks.core.gen.toTypeName
import com.avsystem.justworks.core.model.ApiSpec
import com.avsystem.justworks.core.model.Endpoint
import com.avsystem.justworks.core.model.ParameterLocation
import com.avsystem.justworks.core.model.SecurityScheme
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.ContextParameter
Expand Down Expand Up @@ -59,7 +60,7 @@ internal object ClientGenerator {
): List<FileSpec> {
val grouped = spec.endpoints.groupBy { it.tags.firstOrNull() ?: DEFAULT_TAG }
return grouped.map { (tag, endpoints) ->
generateClientFile(tag, endpoints, hasPolymorphicTypes, nameRegistry)
generateClientFile(tag, endpoints, hasPolymorphicTypes, nameRegistry, spec.securitySchemes)
}
}

Expand All @@ -69,6 +70,7 @@ internal object ClientGenerator {
endpoints: List<Endpoint>,
hasPolymorphicTypes: Boolean,
nameRegistry: NameRegistry,
securitySchemes: List<SecurityScheme>,
): FileSpec {
val className = ClassName(apiPackage, nameRegistry.register("${tag.toPascalCase()}$API_SUFFIX"))

Expand All @@ -80,25 +82,30 @@ internal object ClientGenerator {
}

val tokenType = LambdaTypeName.get(returnType = STRING)
val authParams = ApiClientBaseGenerator.buildAuthConstructorParams(securitySchemes)

val primaryConstructor = FunSpec
val constructorBuilder = FunSpec
.constructorBuilder()
.addParameter(BASE_URL, STRING)
.addParameter(TOKEN, tokenType)
.build()

val classBuilder = TypeSpec
.classBuilder(className)
.superclass(API_CLIENT_BASE)
.addSuperclassConstructorParameter(BASE_URL)

for (paramName in authParams) {
constructorBuilder.addParameter(paramName, tokenType)
classBuilder.addSuperclassConstructorParameter(paramName)
}

val httpClientProperty = PropertySpec
.builder(CLIENT, HTTP_CLIENT)
.addModifiers(KModifier.OVERRIDE, KModifier.PROTECTED)
.initializer(clientInitializer)
.build()

val classBuilder = TypeSpec
.classBuilder(className)
.superclass(API_CLIENT_BASE)
.addSuperclassConstructorParameter(BASE_URL)
.addSuperclassConstructorParameter(TOKEN)
.primaryConstructor(primaryConstructor)
classBuilder
.primaryConstructor(constructorBuilder.build())
.addProperty(httpClientProperty)

val methodRegistry = NameRegistry()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.avsystem.justworks.core.gen.shared

import com.avsystem.justworks.core.gen.API_CLIENT_BASE
import com.avsystem.justworks.core.gen.APPLY_AUTH
import com.avsystem.justworks.core.gen.BASE64_CLASS
import com.avsystem.justworks.core.gen.BASE_URL
import com.avsystem.justworks.core.gen.BODY_AS_TEXT_FUN
import com.avsystem.justworks.core.gen.BODY_FUN
Expand All @@ -28,6 +29,9 @@ import com.avsystem.justworks.core.gen.RAISE_FUN
import com.avsystem.justworks.core.gen.SAFE_CALL
import com.avsystem.justworks.core.gen.SERIALIZERS_MODULE
import com.avsystem.justworks.core.gen.TOKEN
import com.avsystem.justworks.core.gen.toCamelCase
import com.avsystem.justworks.core.model.ApiKeyLocation
import com.avsystem.justworks.core.model.SecurityScheme
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.ContextParameter
import com.squareup.kotlinpoet.ExperimentalKotlinPoetApi
Expand Down Expand Up @@ -59,7 +63,7 @@ internal object ApiClientBaseGenerator {
private const val BLOCK = "block"
private const val NETWORK_ERROR = "Network error"

fun generate(): FileSpec {
fun generate(securitySchemes: List<SecurityScheme>): FileSpec {
val t = TypeVariableName("T").copy(reified = true)

return FileSpec
Expand All @@ -68,7 +72,7 @@ internal object ApiClientBaseGenerator {
.addFunction(buildMapToResult(t))
.addFunction(buildToResult(t))
.addFunction(buildToEmptyResult())
.addType(buildApiClientBaseClass())
.addType(buildApiClientBaseClass(securitySchemes))
.build()
}

Expand Down Expand Up @@ -132,26 +136,37 @@ internal object ApiClientBaseGenerator {
.addStatement("return %L { Unit }", MAP_TO_RESULT)
.build()

private fun buildApiClientBaseClass(): TypeSpec {
private fun buildApiClientBaseClass(securitySchemes: List<SecurityScheme>): TypeSpec {
val tokenType = LambdaTypeName.get(returnType = STRING)
val authParams = buildAuthConstructorParams(securitySchemes)

val constructor = FunSpec
val constructorBuilder = FunSpec
.constructorBuilder()
.addParameter(BASE_URL, STRING)
.addParameter(TOKEN, tokenType)
.build()

val classBuilder = TypeSpec
.classBuilder(API_CLIENT_BASE)
.addModifiers(KModifier.ABSTRACT)
.addSuperinterface(CLOSEABLE)

val baseUrlProp = PropertySpec
.builder(BASE_URL, STRING)
.initializer(BASE_URL)
.addModifiers(KModifier.PROTECTED)
.build()

val tokenProp = PropertySpec
.builder(TOKEN, tokenType)
.initializer(TOKEN)
.addModifiers(KModifier.PRIVATE)
.build()
classBuilder.addProperty(baseUrlProp)

for (paramName in authParams) {
constructorBuilder.addParameter(paramName, tokenType)
classBuilder.addProperty(
PropertySpec
.builder(paramName, tokenType)
.initializer(paramName)
.addModifiers(KModifier.PRIVATE)
.build(),
)
}

val clientProp = PropertySpec
.builder(CLIENT, HTTP_CLIENT)
Expand All @@ -164,32 +179,118 @@ internal object ApiClientBaseGenerator {
.addStatement("$CLIENT.close()")
.build()

return TypeSpec
.classBuilder(API_CLIENT_BASE)
.addModifiers(KModifier.ABSTRACT)
.addSuperinterface(CLOSEABLE)
.primaryConstructor(constructor)
.addProperty(baseUrlProp)
.addProperty(tokenProp)
return classBuilder
.primaryConstructor(constructorBuilder.build())
.addProperty(clientProp)
.addFunction(closeFun)
.addFunction(buildApplyAuth())
.addFunction(buildApplyAuth(securitySchemes))
.addFunction(buildSafeCall())
.addFunction(buildCreateHttpClient())
.build()
}

private fun buildApplyAuth(): FunSpec = FunSpec
.builder(APPLY_AUTH)
.addModifiers(KModifier.PROTECTED)
.receiver(HTTP_REQUEST_BUILDER)
.beginControlFlow("%M", HEADERS_FUN)
.addStatement(
"append(%T.Authorization, %P)",
HTTP_HEADERS,
CodeBlock.of($$"Bearer ${'$'}{$$TOKEN()}"),
).endControlFlow()
.build()
/**
* Builds the list of auth-related constructor parameter names based on security schemes.
*/
internal fun buildAuthConstructorParams(securitySchemes: List<SecurityScheme>): List<String> {
val isSingleBearer = isSingleBearer(securitySchemes)

return securitySchemes.flatMap { scheme ->
when (scheme) {
is SecurityScheme.Bearer if isSingleBearer -> listOf(
TOKEN,
)

is SecurityScheme.Bearer -> listOf(
"${scheme.name.toCamelCase()}Token",
)

is SecurityScheme.ApiKey -> listOf(
"${scheme.name.toCamelCase()}Key",
)

is SecurityScheme.Basic -> listOf(
"${scheme.name.toCamelCase()}Username",
"${scheme.name.toCamelCase()}Password",
)
}
}
}

private fun isSingleBearer(securitySchemes: List<SecurityScheme>): Boolean =
securitySchemes.size == 1 && securitySchemes.first() is SecurityScheme.Bearer

private fun buildApplyAuth(securitySchemes: List<SecurityScheme>): FunSpec {
val builder = FunSpec
.builder(APPLY_AUTH)
.addModifiers(KModifier.PROTECTED)
.receiver(HTTP_REQUEST_BUILDER)

if (securitySchemes.isEmpty()) return builder.build()

val headerSchemes = securitySchemes.filter { scheme ->
scheme is SecurityScheme.Bearer ||
scheme is SecurityScheme.Basic ||
(scheme is SecurityScheme.ApiKey && scheme.location == ApiKeyLocation.HEADER)
}
val querySchemes = securitySchemes
.filterIsInstance<SecurityScheme.ApiKey>()
.filter { it.location == ApiKeyLocation.QUERY }

if (headerSchemes.isNotEmpty()) {
val isSingleBearer = isSingleBearer(securitySchemes)

builder.beginControlFlow("%M", HEADERS_FUN)
for (scheme in headerSchemes) {
when (scheme) {
is SecurityScheme.Bearer -> {
val paramName = if (isSingleBearer) TOKEN else "${scheme.name.toCamelCase()}Token"
builder.addStatement(
"append(%T.Authorization, %P)",
HTTP_HEADERS,
CodeBlock.of($$"Bearer ${$$paramName()}"),
)
}

is SecurityScheme.Basic -> {
val usernameParam = "${scheme.name.toCamelCase()}Username"
val passwordParam = "${scheme.name.toCamelCase()}Password"
builder.addStatement(
"append(%T.Authorization, %P)",
HTTP_HEADERS,
CodeBlock.of(
$$"Basic ${%T.getEncoder().encodeToString(\"${$$usernameParam()}:${$$passwordParam()}\".toByteArray())}",
BASE64_CLASS,
),
)
}

is SecurityScheme.ApiKey -> {
val paramName = "${scheme.name.toCamelCase()}Key"
builder.addStatement(
"append(%S, $paramName())",
scheme.parameterName,
)
}
}
}
builder.endControlFlow()
}

if (querySchemes.isNotEmpty()) {
builder.beginControlFlow("url")
for (scheme in querySchemes) {
val paramName = "${scheme.name.toCamelCase()}Key"
builder.addStatement(
"parameters.append(%S, $paramName())",
scheme.parameterName,
)
}
builder.endControlFlow()
}

return builder.build()
}

private fun buildSafeCall(): FunSpec = FunSpec
.builder(SAFE_CALL)
Expand Down
Loading
Loading