Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 72 additions & 12 deletions rivetkit-typescript/packages/rivetkit/src/client/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,24 @@ import {
getRivetRunner,
} from "@/utils/env-vars";
import { RegistryConfig } from "@/registry/config";
import {
EndpointSchema,
type ParsedEndpoint,
zodCheckDuplicateCredentials,
} from "@/utils/endpoint-parser";

export const ClientConfigSchema = z.object({
/**
* Base client config schema without transforms so it can be merged in to other schemas.
*/
export const ClientConfigSchemaBase = z.object({
/** Endpoint to connect to for Rivet Engine or RivetKit manager API. */
endpoint: z
.string()
.optional()
.transform((x) => x ?? getRivetEngine() ?? getRivetEndpoint()),
endpoint: EndpointSchema.optional(),

/** Token to use to authenticate with the API. */
token: z
.string()
.optional()
.transform((x) => x ?? getRivetToken()),
token: z.string().optional(),

/** Namespace to connect to. */
namespace: z.string().default(() => getRivetNamespace() ?? "default"),
namespace: z.string().optional(),

/** Name of the runner. This is used to group together runners in to different pools. */
runnerName: z.string().default(() => getRivetRunner() ?? "default"),
Expand All @@ -46,26 +48,84 @@ export const ClientConfigSchema = z.object({
disableMetadataLookup: z.boolean().optional().default(false),
});

export const ClientConfigSchema = ClientConfigSchemaBase.transform(
(config, ctx) => transformClientConfig(config, ctx),
);

export type ClientConfig = z.infer<typeof ClientConfigSchema>;

export type ClientConfigInput = z.input<typeof ClientConfigSchema>;

export function resolveEndpoint(
parsedEndpoint: ParsedEndpoint | undefined,
): ParsedEndpoint | undefined {
if (parsedEndpoint) {
return parsedEndpoint;
}

const envEndpoint = getRivetEngine() ?? getRivetEndpoint();
if (envEndpoint) {
return EndpointSchema.parse(envEndpoint);
}

return undefined;
}

export function validateClientConfig(
resolvedEndpoint: ParsedEndpoint | undefined,
config: z.infer<typeof ClientConfigSchemaBase>,
ctx: z.RefinementCtx,
) {
if (resolvedEndpoint) {
zodCheckDuplicateCredentials(resolvedEndpoint, config, ctx);
}
}

export function transformClientConfig(
config: z.infer<typeof ClientConfigSchemaBase>,
ctx?: z.RefinementCtx,
) {
const resolvedEndpoint = resolveEndpoint(config.endpoint);

// Validate if context is provided (when called from Zod transform)
if (ctx) {
validateClientConfig(resolvedEndpoint, config, ctx);
}

return {
...config,
endpoint: resolvedEndpoint?.endpoint,
namespace:
resolvedEndpoint?.namespace ??
config.namespace ??
getRivetNamespace() ??
"default",
token: resolvedEndpoint?.token ?? config.token ?? getRivetToken(),
};
}

/**
* Converts a base config in to a client config.
*
* The base config does not include all of the properties of the client config,
* so this converts the subset of properties in to the client config.
*
* Note: We construct the object directly rather than using ClientConfigSchema.parse()
* because RegistryConfig has already transformed the endpoint, namespace, and token.
* Re-parsing would attempt to extract namespace/token from the endpoint URL again.
*/
export function convertRegistryConfigToClientConfig(
config: RegistryConfig,
): ClientConfig {
return ClientConfigSchema.parse({
return {
endpoint: config.endpoint,
token: config.token,
namespace: config.namespace,
runnerName: config.runner.runnerName,
headers: config.headers,
encoding: "bare",
getUpgradeWebSocket: undefined,
// We don't need health checks for internal clients
disableMetadataLookup: true,
});
};
}
50 changes: 33 additions & 17 deletions rivetkit-typescript/packages/rivetkit/src/drivers/engine/config.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
import { z } from "zod";
import { ClientConfigSchema } from "@/client/config";
import {
ClientConfigSchemaBase,
transformClientConfig,
} from "@/client/config";
import { getRivetRunnerKey } from "@/utils/env-vars";

const EngineConfigSchemaBase = z
.object({
/** Unique key for this runner. Runners connecting a given key will replace any other runner connected with the same key. */
runnerKey: z
.string()
.optional()
.transform((x) => x ?? getRivetRunnerKey()),
/**
* Base engine config schema without transforms so it can be merged in to other schemas.
*
* We include the client config since this includes the common properties like endpoint, namespace, etc.
*/
export const EngineConfigSchemaBase = ClientConfigSchemaBase.extend({
/** Unique key for this runner. Runners connecting a given key will replace any other runner connected with the same key. */
runnerKey: z.string().optional(),

/** How many actors this runner can run. */
totalSlots: z.number().default(100_000),
})
// We include the client config since this includes the common properties like endpoint, namespace, etc.
.merge(ClientConfigSchema);
/** How many actors this runner can run. */
totalSlots: z.number().default(100_000),
});

export const EngingConfigSchema = EngineConfigSchemaBase.default(() =>
EngineConfigSchemaBase.parse({}),
const EngineConfigSchemaTransformed = EngineConfigSchemaBase.transform(
(config, ctx) => transformEngineConfig(config, ctx),
);

export type EngineConfig = z.infer<typeof EngingConfigSchema>;
export type EngineConfigInput = z.input<typeof EngingConfigSchema>;
export const EngineConfigSchema = EngineConfigSchemaTransformed.default(() =>
EngineConfigSchemaTransformed.parse({}),
);

export type EngineConfig = z.infer<typeof EngineConfigSchema>;
export type EngineConfigInput = z.input<typeof EngineConfigSchema>;

export function transformEngineConfig(
config: z.infer<typeof EngineConfigSchemaBase>,
ctx?: z.RefinementCtx,
) {
return {
...transformClientConfig(config, ctx),
runnerKey: config.runnerKey ?? getRivetRunnerKey(),
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export { EngineActorDriver } from "./actor-driver";
export {
type EngineConfig as Config,
type EngineConfigInput as InputConfig,
EngingConfigSchema as ConfigSchema,
EngineConfigSchema as ConfigSchema,
} from "./config";

export function createEngineDriver(): DriverConfig {
Expand Down
63 changes: 40 additions & 23 deletions rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import { DriverConfigSchema, type DriverConfig } from "./driver";
import invariant from "invariant";
import { RunnerConfigSchema } from "./runner";
import { ServerlessConfigSchema } from "./serverless";
import {
EndpointSchema,
zodCheckDuplicateCredentials,
} from "@/utils/endpoint-parser";
import { resolveEndpoint } from "@/client/config";

export { DriverConfigSchema, type DriverConfig };

Expand Down Expand Up @@ -76,15 +81,9 @@ export const RegistryConfigSchema = z
// getUpgradeWebSocket: z.custom<GetUpgradeWebSocket>().optional(),

// MARK: Runner Configuration
endpoint: z
.string()
.optional()
.transform((x) => x ?? getRivetEndpoint()),
token: z
.string()
.optional()
.transform((x) => x ?? getRivetToken()),
namespace: z.string().default(() => getRivetNamespace() ?? "default"),
endpoint: EndpointSchema.optional(),
token: z.string().optional(),
namespace: z.string().optional(),
headers: z.record(z.string(), z.string()).optional().default({}),

// MARK: Client
Expand Down Expand Up @@ -123,10 +122,16 @@ export const RegistryConfigSchema = z
RunnerConfigSchema.parse({}),
),
})
.superRefine((config, ctx) => {
.transform((config, ctx) => {
const isDevEnv = isDev();
const resolvedEndpoint = resolveEndpoint(config.endpoint);

// Validate duplicate credentials
if (resolvedEndpoint) {
zodCheckDuplicateCredentials(resolvedEndpoint, config, ctx);
}

if (config.endpoint && config.serveManager) {
if (resolvedEndpoint && config.serveManager) {
ctx.addIssue({
code: "custom",
message: "cannot specify both endpoint and serveManager",
Expand All @@ -135,7 +140,7 @@ export const RegistryConfigSchema = z

if (config.serverless) {
// Can't spawn engine AND connect to remote endpoint
if (config.serverless.spawnEngine && config.endpoint) {
if (config.serverless.spawnEngine && resolvedEndpoint) {
ctx.addIssue({
code: "custom",
message: "cannot specify both spawnEngine and endpoint",
Expand All @@ -145,7 +150,7 @@ export const RegistryConfigSchema = z
// configureRunnerPool requires an engine (via endpoint or spawnEngine)
if (
config.serverless.configureRunnerPool &&
!config.endpoint &&
!resolvedEndpoint &&
!config.serverless.spawnEngine
) {
ctx.addIssue({
Expand All @@ -158,7 +163,7 @@ export const RegistryConfigSchema = z
// advertiseEndpoint required in production without endpoint
if (
!isDevEnv &&
!config.endpoint &&
!resolvedEndpoint &&
!config.serverless.advertiseEndpoint
) {
ctx.addIssue({
Expand All @@ -169,21 +174,28 @@ export const RegistryConfigSchema = z
});
}
}
})
.transform((config) => {
const isDevEnv = isDev();

// Flatten the endpoint and apply defaults for namespace/token
const endpoint = resolvedEndpoint?.endpoint;
const namespace =
resolvedEndpoint?.namespace ??
config.namespace ??
getRivetNamespace() ??
"default";
const token =
resolvedEndpoint?.token ?? config.token ?? getRivetToken();

if (config.serverless) {
let serveManager: boolean;
let advertiseEndpoint: string;

if (config.endpoint) {
if (endpoint) {
// Remote endpoint provided:
// - Do not start manager server
// - Redirect clients to remote endpoint
serveManager = config.serveManager ?? false;
advertiseEndpoint =
config.serverless.advertiseEndpoint ?? config.endpoint;
config.serverless.advertiseEndpoint ?? endpoint;
} else if (isDevEnv) {
// Development mode, no endpoint:
// - Start manager server
Expand All @@ -205,8 +217,7 @@ export const RegistryConfigSchema = z
}

// If endpoint is set or spawning engine, we'll use engine driver - disable manager inspector
const willUseEngine =
!!config.endpoint || config.serverless.spawnEngine;
const willUseEngine = !!endpoint || config.serverless.spawnEngine;
const inspector = willUseEngine
? {
...config.inspector,
Expand All @@ -216,6 +227,9 @@ export const RegistryConfigSchema = z

return {
...config,
endpoint,
namespace,
token,
serveManager,
advertiseEndpoint,
inspector,
Expand All @@ -226,7 +240,7 @@ export const RegistryConfigSchema = z
// - If dev mode without endpoint: start manager server
// - If prod mode without endpoint: do not start manager server
let serveManager: boolean;
if (config.endpoint) {
if (endpoint) {
serveManager = config.serveManager ?? false;
} else if (isDevEnv) {
serveManager = config.serveManager ?? true;
Expand All @@ -235,7 +249,7 @@ export const RegistryConfigSchema = z
}

// If endpoint is set, we'll use engine driver - disable manager inspector
const willUseEngine = !!config.endpoint;
const willUseEngine = !!endpoint;
const inspector = willUseEngine
? {
...config.inspector,
Expand All @@ -245,6 +259,9 @@ export const RegistryConfigSchema = z

return {
...config,
endpoint,
namespace,
token,
serveManager,
inspector,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import type { Logger } from "pino";
import { z } from "zod";
import type { ActorDriverBuilder } from "@/actor/driver";
import { LogLevelSchema } from "@/common/log";
import { EngingConfigSchema as EngineConfigSchema } from "@/drivers/engine/config";
import {
EngineConfigSchemaBase,
transformEngineConfig,
} from "@/drivers/engine/config";
import { InspectorConfigSchema } from "@/inspector/config";
import type { ManagerDriverBuilder } from "@/manager/driver";
import type { GetUpgradeWebSocket } from "@/utils";
Expand Down Expand Up @@ -134,11 +137,16 @@ const LegacyRunnerConfigSchemaUnmerged = z
// created or must be imported async using `await import(...)`
getUpgradeWebSocket: z.custom<GetUpgradeWebSocket>().optional(),
})
.merge(EngineConfigSchema.removeDefault());
.merge(EngineConfigSchemaBase);

const LegacyRunnerConfigSchemaTransformed =
LegacyRunnerConfigSchemaUnmerged.transform((config, ctx) => ({
...config,
...transformEngineConfig(config, ctx),
}));

const LegacyRunnerConfigSchemaBase = LegacyRunnerConfigSchemaUnmerged;
export const LegacyRunnerConfigSchema = LegacyRunnerConfigSchemaBase.default(() =>
LegacyRunnerConfigSchemaBase.parse({}),
export const LegacyRunnerConfigSchema = LegacyRunnerConfigSchemaTransformed.default(
() => LegacyRunnerConfigSchemaTransformed.parse({}),
);

export type LegacyRunnerConfig = z.infer<typeof LegacyRunnerConfigSchema>;
Expand Down
Loading
Loading