Skip to content

Commit 886fc35

Browse files
committed
chore(rivetkit): make execute generic
1 parent 67f79b0 commit 886fc35

File tree

7 files changed

+115
-55
lines changed

7 files changed

+115
-55
lines changed

rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-drizzle.ts

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,38 @@ export const dbActorDrizzle = actor({
1616
await c.db.execute(
1717
`INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`,
1818
);
19-
const results = (await c.db.execute(
19+
const results = await c.db.execute<{ id: number }>(
2020
`SELECT last_insert_rowid() as id`,
21-
)) as Array<{ id: number }>;
21+
);
2222
return { id: results[0].id };
2323
},
2424
getValues: async (c) => {
25-
const results = (await c.db.execute(
26-
`SELECT * FROM test_data ORDER BY id`,
27-
)) as Array<{
25+
const results = await c.db.execute<{
2826
id: number;
2927
value: string;
3028
payload: string;
3129
created_at: number;
32-
}>;
30+
}>(
31+
`SELECT * FROM test_data ORDER BY id`,
32+
);
3333
return results;
3434
},
3535
getValue: async (c, id: number) => {
36-
const results = (await c.db.execute(
36+
const results = await c.db.execute<{ value: string }>(
3737
`SELECT value FROM test_data WHERE id = ${id}`,
38-
)) as Array<{ value: string }>;
38+
);
3939
return results[0]?.value ?? null;
4040
},
4141
getCount: async (c) => {
42-
const results = (await c.db.execute(
42+
const results = await c.db.execute<{ count: number }>(
4343
`SELECT COUNT(*) as count FROM test_data`,
44-
)) as Array<{ count: number }>;
44+
);
4545
return results[0].count;
4646
},
4747
rawSelectCount: async (c) => {
48-
const results = (await c.db.execute(
48+
const results = await c.db.execute<{ count: number }>(
4949
`SELECT COUNT(*) as count FROM test_data`,
50-
)) as Array<{ count: number }>;
50+
);
5151
return results[0]?.count ?? 0;
5252
},
5353
insertMany: async (c, count: number) => {
@@ -88,15 +88,15 @@ export const dbActorDrizzle = actor({
8888
await c.db.execute(
8989
`INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`,
9090
);
91-
const results = (await c.db.execute(
91+
const results = await c.db.execute<{ id: number }>(
9292
`SELECT last_insert_rowid() as id`,
93-
)) as Array<{ id: number }>;
93+
);
9494
return { id: results[0].id, size };
9595
},
9696
getPayloadSize: async (c, id: number) => {
97-
const results = (await c.db.execute(
97+
const results = await c.db.execute<{ size: number }>(
9898
`SELECT length(payload) as size FROM test_data WHERE id = ${id}`,
99-
)) as Array<{ size: number }>;
99+
);
100100
return results[0]?.size ?? 0;
101101
},
102102
repeatUpdate: async (c, id: number, count: number) => {
@@ -119,9 +119,9 @@ export const dbActorDrizzle = actor({
119119
await c.db.execute(
120120
`BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`,
121121
);
122-
const results = (await c.db.execute(
122+
const results = await c.db.execute<{ value: string }>(
123123
`SELECT value FROM test_data ORDER BY id DESC LIMIT 1`,
124-
)) as Array<{ value: string }>;
124+
);
125125
return results[0]?.value ?? null;
126126
},
127127
triggerSleep: (c) => {

rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/actor-db-raw.ts

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,38 @@ export const dbActorRaw = actor({
2222
await c.db.execute(
2323
`INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()})`,
2424
);
25-
const results = (await c.db.execute(
25+
const results = await c.db.execute<{ id: number }>(
2626
`SELECT last_insert_rowid() as id`,
27-
)) as Array<{ id: number }>;
27+
);
2828
return { id: results[0].id };
2929
},
3030
getValues: async (c) => {
31-
const results = (await c.db.execute(
32-
`SELECT * FROM test_data ORDER BY id`,
33-
)) as Array<{
31+
const results = await c.db.execute<{
3432
id: number;
3533
value: string;
3634
payload: string;
3735
created_at: number;
38-
}>;
36+
}>(
37+
`SELECT * FROM test_data ORDER BY id`,
38+
);
3939
return results;
4040
},
4141
getValue: async (c, id: number) => {
42-
const results = (await c.db.execute(
42+
const results = await c.db.execute<{ value: string }>(
4343
`SELECT value FROM test_data WHERE id = ${id}`,
44-
)) as Array<{ value: string }>;
44+
);
4545
return results[0]?.value ?? null;
4646
},
4747
getCount: async (c) => {
48-
const results = (await c.db.execute(
48+
const results = await c.db.execute<{ count: number }>(
4949
`SELECT COUNT(*) as count FROM test_data`,
50-
)) as Array<{ count: number }>;
50+
);
5151
return results[0].count;
5252
},
5353
rawSelectCount: async (c) => {
54-
const results = (await c.db.execute(
54+
const results = await c.db.execute<{ count: number }>(
5555
`SELECT COUNT(*) as count FROM test_data`,
56-
)) as Array<{ count: number }>;
56+
);
5757
return results[0].count;
5858
},
5959
insertMany: async (c, count: number) => {
@@ -94,15 +94,15 @@ export const dbActorRaw = actor({
9494
await c.db.execute(
9595
`INSERT INTO test_data (value, payload, created_at) VALUES ('payload', '${payload}', ${Date.now()})`,
9696
);
97-
const results = (await c.db.execute(
97+
const results = await c.db.execute<{ id: number }>(
9898
`SELECT last_insert_rowid() as id`,
99-
)) as Array<{ id: number }>;
99+
);
100100
return { id: results[0].id, size };
101101
},
102102
getPayloadSize: async (c, id: number) => {
103-
const results = (await c.db.execute(
103+
const results = await c.db.execute<{ size: number }>(
104104
`SELECT length(payload) as size FROM test_data WHERE id = ${id}`,
105-
)) as Array<{ size: number }>;
105+
);
106106
return results[0]?.size ?? 0;
107107
},
108108
repeatUpdate: async (c, id: number, count: number) => {
@@ -125,9 +125,9 @@ export const dbActorRaw = actor({
125125
await c.db.execute(
126126
`BEGIN; INSERT INTO test_data (value, payload, created_at) VALUES ('${value}', '', ${Date.now()}); UPDATE test_data SET value = '${value}-updated' WHERE id = last_insert_rowid(); COMMIT;`,
127127
);
128-
const results = (await c.db.execute(
128+
const results = await c.db.execute<{ value: string }>(
129129
`SELECT value FROM test_data ORDER BY id DESC LIMIT 1`,
130-
)) as Array<{ value: string }>;
130+
);
131131
return results[0]?.value ?? null;
132132
},
133133
triggerSleep: (c) => {

rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,18 @@ export const workflowAccessActor = actor({
132132
}
133133

134134
await loopCtx.step("access-step", async () => {
135-
await actorLoopCtx.db.execute(
135+
await loopCtx.db.execute(
136136
`INSERT INTO workflow_access_log (created_at) VALUES (${Date.now()})`,
137137
);
138-
const counts = (await actorLoopCtx.db.execute(
138+
const counts = await loopCtx.db.execute<{ count: number }>(
139139
`SELECT COUNT(*) as count FROM workflow_access_log`,
140-
)) as Array<{ count: number }>;
141-
const client = actorLoopCtx.client();
140+
);
141+
const client = loopCtx.client<typeof registry>();
142142

143-
actorLoopCtx.state.outsideDbError = outsideDbError;
144-
actorLoopCtx.state.outsideClientError = outsideClientError;
145-
actorLoopCtx.state.insideDbCount = counts[0]?.count ?? 0;
146-
actorLoopCtx.state.insideClientAvailable =
143+
loopCtx.state.outsideDbError = outsideDbError;
144+
loopCtx.state.outsideClientError = outsideClientError;
145+
loopCtx.state.insideDbCount = counts[0]?.count ?? 0;
146+
loopCtx.state.insideClientAvailable =
147147
typeof client.workflowQueueActor.getForId === "function";
148148
});
149149

rivetkit-typescript/packages/rivetkit/src/db/config.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ export type DatabaseProvider<DB extends RawAccess> = {
6666
* Raw database client with basic exec method
6767
*/
6868
export interface RawDatabaseClient {
69-
exec: (query: string, ...args: unknown[]) => Promise<unknown[]> | unknown[];
69+
exec: <TRow extends Record<string, unknown> = Record<string, unknown>>(
70+
query: string,
71+
...args: unknown[]
72+
) => Promise<TRow[]> | TRow[];
7073
}
7174

7275
/**
@@ -77,10 +80,12 @@ export interface DrizzleDatabaseClient {
7780
// For now, just a marker interface
7881
}
7982

80-
type ExecuteFunction = (
83+
type ExecuteFunction = <
84+
TRow extends Record<string, unknown> = Record<string, unknown>,
85+
>(
8186
query: string,
8287
...args: unknown[]
83-
) => Promise<unknown[]>;
88+
) => Promise<TRow[]>;
8489

8590
export type RawAccess = {
8691
/**

rivetkit-typescript/packages/rivetkit/src/db/drizzle/mod.ts

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,36 @@ export function db<
163163
const client = proxyDrizzle<TSchema>(callback, config);
164164

165165
return Object.assign(client, {
166-
execute: async (query: string, ...args: unknown[]) => {
167-
const result = await callback(query, args, "all");
168-
return result.rows;
166+
execute: async <
167+
TRow extends Record<string, unknown> = Record<string, unknown>,
168+
>(
169+
query: string,
170+
...args: unknown[]
171+
): Promise<TRow[]> => {
172+
if (args.length > 0) {
173+
const { rows, columns } = await waDb.query(query, args);
174+
return rows.map((row: unknown[]) => {
175+
const rowObj: Record<string, unknown> = {};
176+
for (let i = 0; i < row.length; i++) {
177+
rowObj[columns[i]] = row[i];
178+
}
179+
return rowObj;
180+
}) as TRow[];
181+
}
182+
183+
const results: Record<string, unknown>[] = [];
184+
let columnNames: string[] | null = null;
185+
await waDb.exec(query, (row: unknown[], columns: string[]) => {
186+
if (!columnNames) {
187+
columnNames = columns;
188+
}
189+
const rowObj: Record<string, unknown> = {};
190+
for (let i = 0; i < row.length; i++) {
191+
rowObj[columnNames[i]] = row[i];
192+
}
193+
results.push(rowObj);
194+
});
195+
return results as TRow[];
169196
},
170197
close: async () => {
171198
await waDb.close();

rivetkit-typescript/packages/rivetkit/src/db/mod.ts

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ export function db({
4646
if (override) {
4747
// Use the override
4848
return {
49-
execute: async (query, ...args) => {
50-
return override.exec(query, ...args);
49+
execute: async <
50+
TRow extends Record<string, unknown> = Record<string, unknown>,
51+
>(
52+
query: string,
53+
...args: unknown[]
54+
): Promise<TRow[]> => {
55+
return await override.exec<TRow>(query, ...args);
5156
},
5257
close: async () => {
5358
// Override clients don't need cleanup
@@ -64,7 +69,12 @@ export function db({
6469
const db = await ctx.sqliteVfs.open(ctx.actorId, kvStore);
6570

6671
return {
67-
execute: async (query, ...args) => {
72+
execute: async <
73+
TRow extends Record<string, unknown> = Record<string, unknown>,
74+
>(
75+
query: string,
76+
...args: unknown[]
77+
): Promise<TRow[]> => {
6878
if (args.length > 0) {
6979
// Use parameterized query when args are provided
7080
const { rows, columns } = await db.query(query, args);
@@ -74,7 +84,7 @@ export function db({
7484
rowObj[columns[i]] = row[i];
7585
}
7686
return rowObj;
77-
});
87+
}) as TRow[];
7888
}
7989

8090
// Use exec for non-parameterized queries
@@ -90,7 +100,7 @@ export function db({
90100
}
91101
results.push(rowObj);
92102
});
93-
return results;
103+
return results as TRow[];
94104
},
95105
close: async () => {
96106
await db.close();

rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { actor, event, queue } from "@/actor/mod";
33
import type { ActorContext, ActorContextOf } from "@/actor/contexts";
44
import type { ActorDefinition } from "@/actor/definition";
55
import type { DatabaseProviderContext } from "@/db/config";
6+
import { db } from "@/db/mod";
67
import { workflow } from "@/workflow/mod";
78

89
describe("ActorDefinition", () => {
@@ -239,4 +240,21 @@ describe("ActorDefinition", () => {
239240
});
240241
});
241242
});
243+
244+
describe("database type inference", () => {
245+
it("supports typed rows for c.db.execute", () => {
246+
actor({
247+
state: {},
248+
db: db(),
249+
actions: {
250+
readFoo: async (c) => {
251+
const rows = await c.db.execute<{ foo: string }>(
252+
"SELECT foo FROM bar",
253+
);
254+
expectTypeOf(rows).toEqualTypeOf<Array<{ foo: string }>>();
255+
},
256+
},
257+
});
258+
});
259+
});
242260
});

0 commit comments

Comments
 (0)