@@ -2,8 +2,9 @@ import type { RemoteTools } from './remote-tools';
22import type { ClientOptions } from 'openai' ;
33import type { ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions' ;
44
5- import { Mistral } from '@mistralai/mistralai' ;
6- import OpenAI from 'openai' ;
5+ import { AIMessage } from '@langchain/core/messages' ;
6+ import { ChatMistralAI } from '@langchain/mistralai' ;
7+ import { ChatOpenAI } from '@langchain/openai' ;
78
89import {
910 AINotConfiguredError ,
@@ -13,7 +14,8 @@ import {
1314
/**
 * Configuration for the OpenAI provider: standard OpenAI client options plus
 * the discriminant `provider` tag and the model to use.
 */
export type OpenAiConfiguration = ClientOptions & {
  provider: 'openai';
  // (string & NonNullable<unknown>) allows custom models while preserving autocomplete for known models
  model: ChatCompletionCreateParamsNonStreaming['model'] | (string & NonNullable<unknown>);
};
1820
1921export type MistralConfiguration = {
@@ -33,10 +35,15 @@ export type OpenAIBody = Pick<
3335
/** Request body accepted by {@link ProviderDispatcher.dispatch}; alias of OpenAIBody. */
export type DispatchBody = OpenAIBody;
/**
 * Minimal structural view of a LangChain chat model — only the two members
 * this dispatcher actually calls. Concrete ChatOpenAI / ChatMistralAI
 * instances are narrowed to this shape via `as unknown as LangChainClient`
 * in the constructor so both providers can share one code path.
 */
type LangChainClient = {
  // NOTE(review): messages are deliberately `unknown` to stay provider-agnostic;
  // dispatch() forwards the OpenAI-shaped messages array as-is.
  invoke: (messages: unknown) => Promise<AIMessage>;
  // Returns a tool-bound copy of the client; options carries e.g. tool_choice.
  bindTools: (tools: unknown, options?: unknown) => LangChainClient;
};
export class ProviderDispatcher {
  // LangChain chat client for the configured provider; null when AI is not configured.
  private readonly client: LangChainClient | null = null;

  // Provider tag captured at construction; used to choose the error type when a call fails.
  private readonly provider: AiProvider | null = null;

  // Model identifier echoed back in converted responses ('' when unconfigured).
  private readonly model: string;
@@ -45,129 +52,97 @@ export class ProviderDispatcher {
4552 constructor ( configuration : AiConfiguration | null , remoteTools : RemoteTools ) {
4653 this . remoteTools = remoteTools ;
4754 this . model = configuration ?. model ?? '' ;
55+ this . provider = configuration ?. provider ?? null ;
4856
4957 if ( configuration ?. provider === 'openai' && configuration . apiKey ) {
50- const { provider, model, ...clientOptions } = configuration ;
51- this . openAiClient = new OpenAI ( clientOptions ) ;
58+ const { provider, model, apiKey, ...clientOptions } = configuration ;
59+ this . client = new ChatOpenAI ( {
60+ apiKey,
61+ model,
62+ configuration : clientOptions ,
63+ } ) as unknown as LangChainClient ;
5264 }
5365
5466 if ( configuration ?. provider === 'mistral' && configuration . apiKey ) {
55- this . mistralClient = new Mistral ( { apiKey : configuration . apiKey } ) ;
67+ this . client = new ChatMistralAI ( {
68+ apiKey : configuration . apiKey ,
69+ model : configuration . model ,
70+ } ) as unknown as LangChainClient ;
5671 }
5772 }
5873
5974 async dispatch ( body : DispatchBody ) : Promise < unknown > {
60- if ( this . openAiClient ) {
61- return this . dispatchOpenAI ( body ) ;
62- }
63-
64- if ( this . mistralClient ) {
65- return this . dispatchMistral ( body ) ;
75+ if ( ! this . client ) {
76+ throw new AINotConfiguredError ( ) ;
6677 }
6778
68- throw new AINotConfiguredError ( ) ;
69- }
70-
71- private async dispatchOpenAI ( body : DispatchBody ) : Promise < unknown > {
7279 const { tools, messages, tool_choice : toolChoice } = body ;
7380
7481 try {
75- return await this . openAiClient ! . chat . completions . create ( {
76- model : this . model ,
77- tools : this . enhanceRemoteTools ( tools ) ,
78- messages,
79- tool_choice : toolChoice ,
80- } as ChatCompletionCreateParamsNonStreaming ) ;
82+ const enhancedTools = this . enhanceRemoteTools ( tools ) ;
83+
84+ const clientWithTools =
85+ enhancedTools && enhancedTools . length > 0
86+ ? this . client . bindTools ( enhancedTools , { tool_choice : toolChoice } )
87+ : this . client ;
88+
89+ const response = await clientWithTools . invoke ( messages ) ;
90+
91+ return this . convertAIMessageToOpenAI ( response ) ;
8192 } catch ( error ) {
93+ if ( this . provider === 'mistral' ) {
94+ throw new MistralUnprocessableError (
95+ `Error while calling Mistral: ${ ( error as Error ) . message } ` ,
96+ ) ;
97+ }
98+
8299 throw new OpenAIUnprocessableError ( `Error while calling OpenAI: ${ ( error as Error ) . message } ` ) ;
83100 }
84101 }
85102
86- private async dispatchMistral ( body : DispatchBody ) : Promise < unknown > {
87- const { tools, messages, tool_choice : toolChoice } = body ;
103+ private convertAIMessageToOpenAI ( message : AIMessage ) : unknown {
104+ const toolCalls = message . tool_calls ?. map ( tc => ( {
105+ id : tc . id ?? `call_${ Date . now ( ) } ` ,
106+ type : 'function' as const ,
107+ function : {
108+ name : tc . name ,
109+ arguments : typeof tc . args === 'string' ? tc . args : JSON . stringify ( tc . args ) ,
110+ } ,
111+ } ) ) ;
88112
89- try {
90- const response = await this . mistralClient ! . chat . complete ( {
91- model : this . model ,
92- tools : this . enhanceRemoteTools ( tools ) as Parameters <
93- typeof this . mistralClient . chat . complete
94- > [ 0 ] [ 'tools' ] ,
95- messages : messages as Parameters < typeof this . mistralClient . chat . complete > [ 0 ] [ 'messages' ] ,
96- toolChoice : toolChoice as 'auto' | 'none' | 'required' ,
97- } ) ;
98-
99- return this . convertMistralToOpenAI ( response ) ;
100- } catch ( error ) {
101- throw new MistralUnprocessableError (
102- `Error while calling Mistral: ${ ( error as Error ) . message } ` ,
103- ) ;
104- }
105- }
113+ // Usage metadata types vary by provider, use type assertions
114+ const usageMetadata = message . usage_metadata as
115+ | { input_tokens ?: number ; output_tokens ?: number ; total_tokens ?: number }
116+ | undefined ;
117+ const tokenUsage = ( message . response_metadata as { tokenUsage ?: Record < string , number > } )
118+ ?. tokenUsage ;
106119
107- private convertMistralToOpenAI ( response : unknown ) : unknown {
108- const mistralResponse = response as {
109- id ?: string ;
110- model ?: string ;
111- choices ?: Array < {
112- index ?: number ;
113- message ?: {
114- role ?: string ;
115- content ?: string | null ;
116- toolCalls ?: Array < {
117- id ?: string ;
118- function ?: { name ?: string ; arguments ?: string } ;
119- } > ;
120- } ;
121- finishReason ?: string ;
122- } > ;
123- usage ?: {
124- promptTokens ?: number ;
125- completionTokens ?: number ;
126- totalTokens ?: number ;
127- } ;
128- } ;
120+ const content = typeof message . content === 'string' ? message . content : null ;
129121
130122 return {
131- id : mistralResponse . id ?? `chatcmpl-${ Date . now ( ) } ` ,
123+ id : message . id ?? `chatcmpl-${ Date . now ( ) } ` ,
132124 object : 'chat.completion' ,
133125 created : Math . floor ( Date . now ( ) / 1000 ) ,
134- model : mistralResponse . model ?? this . model ,
135- choices : ( mistralResponse . choices ?? [ ] ) . map ( choice => ( {
136- index : choice . index ?? 0 ,
137- message : {
138- role : choice . message ?. role ?? 'assistant' ,
139- content : choice . message ?. content ?? null ,
140- ...( choice . message ?. toolCalls && choice . message . toolCalls . length > 0
141- ? {
142- tool_calls : choice . message . toolCalls . map ( tc => ( {
143- id : tc . id ,
144- type : 'function' ,
145- function : {
146- name : tc . function ?. name ,
147- arguments : tc . function ?. arguments ,
148- } ,
149- } ) ) ,
150- }
151- : { } ) ,
126+ model : this . model ,
127+ choices : [
128+ {
129+ index : 0 ,
130+ message : {
131+ role : 'assistant' ,
132+ content,
133+ ...( toolCalls && toolCalls . length > 0 ? { tool_calls : toolCalls } : { } ) ,
134+ } ,
135+ finish_reason : toolCalls && toolCalls . length > 0 ? 'tool_calls' : 'stop' ,
152136 } ,
153- finish_reason : this . convertFinishReason ( choice . finishReason ) ,
154- } ) ) ,
137+ ] ,
155138 usage : {
156- prompt_tokens : mistralResponse . usage ?. promptTokens ?? 0 ,
157- completion_tokens : mistralResponse . usage ?. completionTokens ?? 0 ,
158- total_tokens : mistralResponse . usage ?. totalTokens ?? 0 ,
139+ prompt_tokens : usageMetadata ?. input_tokens ?? tokenUsage ?. promptTokens ?? 0 ,
140+ completion_tokens : usageMetadata ?. output_tokens ?? tokenUsage ?. completionTokens ?? 0 ,
141+ total_tokens : usageMetadata ?. total_tokens ?? tokenUsage ?. totalTokens ?? 0 ,
159142 } ,
160143 } ;
161144 }
162145
163- private convertFinishReason ( reason ?: string ) : string {
164- if ( reason === 'tool_calls' ) return 'tool_calls' ;
165- if ( reason === 'stop' ) return 'stop' ;
166- if ( reason === 'length' ) return 'length' ;
167-
168- return 'stop' ;
169- }
170-
171146 private enhanceRemoteTools ( tools ?: ChatCompletionCreateParamsNonStreaming [ 'tools' ] ) {
172147 if ( ! tools || ! Array . isArray ( tools ) ) return tools ;
173148
0 commit comments