Skip to content

Commit 2718df0

Browse files
committed
feat: add missing tool resolution strategy
1 parent 4c12dc4 commit 2718df0

File tree

4 files changed

+179
-34
lines changed

4 files changed

+179
-34
lines changed

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import com.google.adk.tools.MissingToolResolutionStrategy;
1920
import com.google.auto.value.AutoValue;
2021
import com.google.common.collect.ImmutableList;
2122
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -70,6 +71,8 @@ public enum ToolExecutionMode {
7071

7172
public abstract int maxLlmCalls();
7273

74+
public abstract MissingToolResolutionStrategy missingToolResolutionStrategy();
75+
7376
public abstract Builder toBuilder();
7477

7578
public static Builder builder() {
@@ -78,6 +81,7 @@ public static Builder builder() {
7881
.setResponseModalities(ImmutableList.of())
7982
.setStreamingMode(StreamingMode.NONE)
8083
.setToolExecutionMode(ToolExecutionMode.NONE)
84+
.setMissingToolResolutionStrategy(MissingToolResolutionStrategy.THROW_EXCEPTION)
8185
.setMaxLlmCalls(500);
8286
}
8387

@@ -90,7 +94,8 @@ public static Builder builder(RunConfig runConfig) {
9094
.setResponseModalities(runConfig.responseModalities())
9195
.setSpeechConfig(runConfig.speechConfig())
9296
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
93-
.setInputAudioTranscription(runConfig.inputAudioTranscription());
97+
.setInputAudioTranscription(runConfig.inputAudioTranscription())
98+
.setMissingToolResolutionStrategy(runConfig.missingToolResolutionStrategy());
9499
}
95100

96101
/** Builder for {@link RunConfig}. */
@@ -123,6 +128,10 @@ public abstract Builder setInputAudioTranscription(
123128
@CanIgnoreReturnValue
124129
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
125130

131+
@CanIgnoreReturnValue
132+
public abstract Builder setMissingToolResolutionStrategy(
133+
MissingToolResolutionStrategy missingToolResolutionStrategy);
134+
126135
abstract RunConfig autoBuild();
127136

128137
public RunConfig build() {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
import com.google.adk.events.EventActions;
3131
import com.google.adk.tools.BaseTool;
3232
import com.google.adk.tools.FunctionTool;
33+
import com.google.adk.tools.MissingToolResolutionStrategy;
3334
import com.google.adk.tools.ToolConfirmation;
3435
import com.google.adk.tools.ToolContext;
35-
import com.google.common.base.VerifyException;
3636
import com.google.common.collect.ImmutableList;
3737
import com.google.common.collect.ImmutableMap;
3838
import com.google.genai.types.Content;
@@ -72,6 +72,84 @@ public static String generateClientFunctionCallId() {
7272
return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID();
7373
}
7474

75+
/** Container for separated valid and missing tool calls. */
76+
private static class ToolCallSeparation {
77+
private final ImmutableList<FunctionCall> validCalls;
78+
private final Flowable<Event> missingToolsFlowable;
79+
80+
ToolCallSeparation(
81+
ImmutableList<FunctionCall> validCalls, Flowable<Event> missingToolsFlowable) {
82+
this.validCalls = validCalls;
83+
this.missingToolsFlowable = missingToolsFlowable;
84+
}
85+
86+
ImmutableList<FunctionCall> validCalls() {
87+
return validCalls;
88+
}
89+
90+
Flowable<Event> missingToolsFlowable() {
91+
return missingToolsFlowable;
92+
}
93+
}
94+
95+
/**
96+
* Separates function calls into valid calls and missing tool events.
97+
*
98+
* @param invocationContext The invocation context.
99+
* @param functionCalls The list of function calls to separate.
100+
* @param tools The available tools.
101+
* @return A ToolCallSeparation containing valid calls and a flowable for missing tools.
102+
*/
103+
private static ToolCallSeparation separateValidAndMissingToolCalls(
104+
InvocationContext invocationContext,
105+
ImmutableList<FunctionCall> functionCalls,
106+
Map<String, BaseTool> tools) {
107+
MissingToolResolutionStrategy missingToolResolutionStrategy =
108+
invocationContext.runConfig().missingToolResolutionStrategy();
109+
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
110+
ImmutableList.Builder<FunctionCall> validCalls = ImmutableList.builder();
111+
112+
for (FunctionCall functionCall : functionCalls) {
113+
if (!tools.containsKey(functionCall.name().get())) {
114+
missingTools.add(
115+
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
116+
} else {
117+
validCalls.add(functionCall);
118+
}
119+
}
120+
121+
Flowable<Event> missingToolsFlowable =
122+
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
123+
124+
return new ToolCallSeparation(validCalls.build(), missingToolsFlowable);
125+
}
126+
127+
/**
128+
* Creates a combined flowable of function response events based on execution mode.
129+
*
130+
* @param invocationContext The invocation context.
131+
* @param validCalls The list of valid function calls.
132+
* @param missingToolsFlowable The flowable for missing tool events.
133+
* @param functionCallMapper The mapper to convert function calls to events.
134+
* @return A combined flowable of all events.
135+
*/
136+
private static Flowable<Event> createCombinedFlowable(
137+
InvocationContext invocationContext,
138+
ImmutableList<FunctionCall> validCalls,
139+
Flowable<Event> missingToolsFlowable,
140+
Function<FunctionCall, Maybe<Event>> functionCallMapper) {
141+
Flowable<Event> functionResponseEventsFlowable;
142+
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
143+
functionResponseEventsFlowable =
144+
Flowable.fromIterable(validCalls).concatMapMaybe(functionCallMapper);
145+
} else {
146+
functionResponseEventsFlowable =
147+
Flowable.fromIterable(validCalls).flatMapMaybe(functionCallMapper);
148+
}
149+
150+
return Flowable.concat(missingToolsFlowable, functionResponseEventsFlowable);
151+
}
152+
75153
/**
76154
* Populates missing function call IDs in the provided event's content.
77155
*
@@ -137,12 +215,8 @@ public static Maybe<Event> handleFunctionCalls(
137215
Map<String, BaseTool> tools,
138216
Map<String, ToolConfirmation> toolConfirmations) {
139217
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
140-
141-
for (FunctionCall functionCall : functionCalls) {
142-
if (!tools.containsKey(functionCall.name().get())) {
143-
throw new VerifyException("Tool not found: " + functionCall.name().get());
144-
}
145-
}
218+
ToolCallSeparation separation =
219+
separateValidAndMissingToolCalls(invocationContext, functionCalls, tools);
146220

147221
Function<FunctionCall, Maybe<Event>> functionCallMapper =
148222
functionCall -> {
@@ -199,15 +273,13 @@ public static Maybe<Event> handleFunctionCalls(
199273
});
200274
};
201275

202-
Flowable<Event> functionResponseEventsFlowable;
203-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
204-
functionResponseEventsFlowable =
205-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
206-
} else {
207-
functionResponseEventsFlowable =
208-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
209-
}
210-
return functionResponseEventsFlowable
276+
Flowable<Event> allEventsFlowable =
277+
createCombinedFlowable(
278+
invocationContext,
279+
separation.validCalls(),
280+
separation.missingToolsFlowable(),
281+
functionCallMapper);
282+
return allEventsFlowable
211283
.toList()
212284
.flatMapMaybe(
213285
events -> {
@@ -242,12 +314,8 @@ public static Maybe<Event> handleFunctionCalls(
242314
public static Maybe<Event> handleFunctionCallsLive(
243315
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
244316
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
245-
246-
for (FunctionCall functionCall : functionCalls) {
247-
if (!tools.containsKey(functionCall.name().get())) {
248-
throw new VerifyException("Tool not found: " + functionCall.name().get());
249-
}
250-
}
317+
ToolCallSeparation separation =
318+
separateValidAndMissingToolCalls(invocationContext, functionCalls, tools);
251319

252320
Function<FunctionCall, Maybe<Event>> functionCallMapper =
253321
functionCall -> {
@@ -310,18 +378,14 @@ public static Maybe<Event> handleFunctionCallsLive(
310378
});
311379
};
312380

313-
Flowable<Event> responseEventsFlowable;
314-
315-
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
316-
responseEventsFlowable =
317-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
318-
319-
} else {
320-
responseEventsFlowable =
321-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
322-
}
381+
Flowable<Event> allEventsFlowable =
382+
createCombinedFlowable(
383+
invocationContext,
384+
separation.validCalls(),
385+
separation.missingToolsFlowable(),
386+
functionCallMapper);
323387

324-
return responseEventsFlowable
388+
return allEventsFlowable
325389
.toList()
326390
.flatMapMaybe(
327391
events -> {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.google.adk.tools;
2+
3+
import com.google.adk.agents.InvocationContext;
4+
import com.google.adk.events.Event;
5+
import com.google.common.base.VerifyException;
6+
import com.google.genai.types.FunctionCall;
7+
import io.reactivex.rxjava3.core.Maybe;
8+
import java.util.function.BiFunction;
9+
10+
public interface MissingToolResolutionStrategy {
11+
public static final MissingToolResolutionStrategy THROW_EXCEPTION =
12+
(invocationContext, functionCall) -> {
13+
throw new VerifyException(
14+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson()));
15+
};
16+
17+
public static final MissingToolResolutionStrategy RETURN_ERROR =
18+
(invocationContext, functionCall) ->
19+
Maybe.error(
20+
new VerifyException(
21+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson())));
22+
23+
public static final MissingToolResolutionStrategy IGNORE =
24+
(invocationContext, functionCall) -> Maybe.empty();
25+
26+
public static MissingToolResolutionStrategy respondWithEvent(
27+
BiFunction<InvocationContext, FunctionCall, Maybe<Event>> eventFactory) {
28+
return eventFactory::apply;
29+
}
30+
31+
public static MissingToolResolutionStrategy respondWithEventSync(
32+
BiFunction<InvocationContext, FunctionCall, Event> eventFactory) {
33+
return respondWithEvent(
34+
(invocationContext, functionCall) ->
35+
Maybe.just(eventFactory.apply(invocationContext, functionCall)));
36+
}
37+
38+
Maybe<Event> onMissingTool(InvocationContext invocationContext, FunctionCall functionCall);
39+
}

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import static org.junit.Assert.assertThrows;
2424

2525
import com.google.adk.agents.InvocationContext;
26+
import com.google.adk.agents.RunConfig;
2627
import com.google.adk.events.Event;
2728
import com.google.adk.testing.TestUtils;
29+
import com.google.adk.tools.MissingToolResolutionStrategy;
2830
import com.google.common.collect.ImmutableList;
2931
import com.google.common.collect.ImmutableMap;
3032
import com.google.genai.types.Content;
@@ -67,6 +69,37 @@ public void handleFunctionCalls_missingTool() {
6769
invocationContext, event, /* tools= */ ImmutableMap.of()));
6870
}
6971

72+
@Test
73+
public void handleFunctionCalls_missingTool_recoveryStrategy() {
74+
InvocationContext invocationContext =
75+
createInvocationContext(
76+
createRootAgent(),
77+
RunConfig.builder()
78+
.setMissingToolResolutionStrategy(
79+
MissingToolResolutionStrategy.respondWithEventSync(
80+
(ctx, call) ->
81+
Event.builder()
82+
.content(
83+
Content.fromParts(
84+
Part.fromText("tool missing: " + call.name().get())))
85+
.build()))
86+
.build());
87+
Event event =
88+
createEvent("event").toBuilder()
89+
.content(
90+
Content.fromParts(
91+
Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of())))
92+
.build();
93+
94+
Event functionResponseEvent =
95+
Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of())
96+
.blockingGet();
97+
98+
assertThat(functionResponseEvent).isNotNull();
99+
assertThat(functionResponseEvent.content().get().parts().get())
100+
.containsExactly(Part.fromText("tool missing: missing_tool"));
101+
}
102+
70103
@Test
71104
public void handleFunctionCalls_singleFunctionCall() {
72105
InvocationContext invocationContext = createInvocationContext(createRootAgent());

0 commit comments

Comments
 (0)