Skip to content

Commit 5077236

Browse files
authored
Merge pull request #1859 from dgageot/fix-script-args
Fix script args with DMR
2 parents 5b32cd9 + 362a24b commit 5077236

File tree

6 files changed

+254
-20
lines changed

6 files changed

+254
-20
lines changed

pkg/model/provider/dmr/client.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,6 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
472472
slog.Debug("Adding tools to DMR request", "tool_count", len(requestTools))
473473
toolsParam := make([]openai.ChatCompletionToolUnionParam, len(requestTools))
474474
for i, tool := range requestTools {
475-
// DMR requires the `description` key to be present; ensure a non-empty value
476-
// NOTE(krissetto): workaround, remove when fixed upstream, this shouldn't be necceessary
477-
desc := cmp.Or(tool.Description, "Function "+tool.Name)
478-
479475
parameters, err := ConvertParametersToSchema(tool.Parameters)
480476
if err != nil {
481477
slog.Error("Failed to convert tool parameters to DMR schema", "error", err, "tool", tool.Name)
@@ -488,9 +484,11 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
488484
return nil, fmt.Errorf("converted parameters is not a map for tool %s", tool.Name)
489485
}
490486

487+
// DMR requires the `description` key to be present; ensure a non-empty value
488+
// NOTE(krissetto): workaround, remove when fixed upstream, this shouldn't be necessary
491489
toolsParam[i] = openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
492490
Name: tool.Name,
493-
Description: openai.String(desc),
491+
Description: openai.String(cmp.Or(tool.Description, "Function "+tool.Name)),
494492
Parameters: paramsMap,
495493
})
496494
}
@@ -872,20 +870,6 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
872870
return scores, nil
873871
}
874872

875-
// ConvertParametersToSchema converts parameters to DMR Schema format
876-
func ConvertParametersToSchema(params any) (any, error) {
877-
m, err := tools.SchemaToMap(params)
878-
if err != nil {
879-
return nil, err
880-
}
881-
882-
// DMR models tend to dislike `additionalProperties` in the schema
883-
// e.g. ai/qwen3 and ai/gpt-oss
884-
delete(m, "additionalProperties")
885-
886-
return m, nil
887-
}
888-
889873
type speculativeDecodingOpts struct {
890874
draftModel string
891875
numTokens int

pkg/model/provider/dmr/schema.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package dmr
2+
3+
import (
4+
"github.com/docker/cagent/pkg/tools"
5+
)
6+
7+
// ConvertParametersToSchema converts parameters to DMR Schema format
8+
func ConvertParametersToSchema(params any) (any, error) {
9+
m, err := tools.SchemaToMap(params)
10+
if err != nil {
11+
return nil, err
12+
}
13+
14+
// DMR models tend to dislike `additionalProperties` in the schema
15+
// e.g. ai/qwen3 and ai/gpt-oss
16+
delete(m, "additionalProperties")
17+
18+
return m, nil
19+
}

pkg/tools/builtin/script_shell.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"maps"
89
"os"
910
"os/exec"
1011
"slices"
@@ -110,7 +111,7 @@ func (t *ScriptShellTool) Tools(context.Context) ([]tools.Tool, error) {
110111

111112
inputSchema, err := tools.SchemaToMap(map[string]any{
112113
"type": "object",
113-
"properties": cfg.Args,
114+
"properties": defaultPropertyTypes(cfg.Args, "string"),
114115
"required": cfg.Required,
115116
})
116117
if err != nil {
@@ -158,3 +159,20 @@ func (t *ScriptShellTool) execute(ctx context.Context, toolConfig *latest.Script
158159

159160
return tools.ResultSuccess(limitOutput(string(output))), nil
160161
}
162+
163+
// defaultPropertyTypes returns a copy of properties where any property
164+
// missing a "type" field gets the given default type.
165+
func defaultPropertyTypes(properties map[string]any, defaultType string) map[string]any {
166+
result := make(map[string]any, len(properties))
167+
for k, v := range properties {
168+
if prop, ok := v.(map[string]any); ok && prop["type"] == nil {
169+
propCopy := make(map[string]any, len(prop)+1)
170+
maps.Copy(propCopy, prop)
171+
propCopy["type"] = defaultType
172+
result[k] = propCopy
173+
continue
174+
}
175+
result[k] = v
176+
}
177+
return result
178+
}

pkg/tools/builtin/script_shell_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,38 @@ func TestNewScriptShellTool_MissingRequired(t *testing.T) {
115115
require.Nil(t, tool)
116116
require.ErrorContains(t, err, "tool 'docker_images' has required arg 'img' which is not defined in args")
117117
}
118+
119+
func TestNewScriptShellTool_ArgWithoutType(t *testing.T) {
120+
shellTools := map[string]latest.ScriptShellToolConfig{
121+
"greet": {
122+
Description: "Greet someone",
123+
Cmd: "echo Hello $name",
124+
Args: map[string]any{
125+
"name": map[string]any{
126+
"description": "Name to greet",
127+
},
128+
},
129+
Required: []string{"name"},
130+
},
131+
}
132+
133+
tool, err := NewScriptShellTool(shellTools, nil)
134+
require.NoError(t, err)
135+
136+
allTools, err := tool.Tools(t.Context())
137+
require.NoError(t, err)
138+
assert.Len(t, allTools, 1)
139+
140+
schema, err := json.Marshal(allTools[0].Parameters)
141+
require.NoError(t, err)
142+
assert.JSONEq(t, `{
143+
"type": "object",
144+
"properties": {
145+
"name": {
146+
"description": "Name to greet",
147+
"type": "string"
148+
}
149+
},
150+
"required": ["name"]
151+
}`, string(schema))
152+
}

pkg/tools/schema.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,41 @@ func SchemaToMap(params any) (map[string]any, error) {
4747
delete(m, "required")
4848
}
4949

50+
// Ensure all properties have a type set, recursively.
51+
ensurePropertyTypes(m)
52+
5053
return m, nil
5154
}
5255

56+
// ensurePropertyTypes recursively walks a JSON Schema map and ensures
57+
// every property has a "type" set, defaulting to "object" if missing.
58+
// It descends into nested "properties" and array "items".
59+
func ensurePropertyTypes(schema map[string]any) {
60+
props, ok := schema["properties"].(map[string]any)
61+
if !ok {
62+
return
63+
}
64+
65+
for _, v := range props {
66+
prop, ok := v.(map[string]any)
67+
if !ok {
68+
continue
69+
}
70+
71+
if prop["type"] == nil {
72+
prop["type"] = "object"
73+
}
74+
75+
// Recurse into nested object properties.
76+
ensurePropertyTypes(prop)
77+
78+
// Recurse into array items.
79+
if items, ok := prop["items"].(map[string]any); ok {
80+
ensurePropertyTypes(items)
81+
}
82+
}
83+
}
84+
5385
func ConvertSchema(params, v any) error {
5486
// First unmarshal to a map to check we have a type and non-nil properties
5587
m, err := SchemaToMap(params)

pkg/tools/schema_test.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,149 @@ func TestSchemaToMap_MissingEmptyProperties(t *testing.T) {
4040
"properties": map[string]any{},
4141
}, m)
4242
}
43+
44+
func TestSchemaToMap_PropertyWithoutType(t *testing.T) {
45+
m, err := SchemaToMap(map[string]any{
46+
"type": "object",
47+
"properties": map[string]any{
48+
"name": map[string]any{
49+
"type": "string",
50+
},
51+
"metadata": map[string]any{
52+
"description": "some metadata",
53+
},
54+
},
55+
})
56+
require.NoError(t, err)
57+
58+
assert.Equal(t, map[string]any{
59+
"type": "object",
60+
"properties": map[string]any{
61+
"name": map[string]any{
62+
"type": "string",
63+
},
64+
"metadata": map[string]any{
65+
"type": "object",
66+
"description": "some metadata",
67+
},
68+
},
69+
}, m)
70+
}
71+
72+
func TestSchemaToMap_NestedPropertyWithoutType(t *testing.T) {
73+
m, err := SchemaToMap(map[string]any{
74+
"type": "object",
75+
"properties": map[string]any{
76+
"config": map[string]any{
77+
"type": "object",
78+
"properties": map[string]any{
79+
"host": map[string]any{
80+
"type": "string",
81+
},
82+
"metadata": map[string]any{
83+
"description": "nested metadata without type",
84+
},
85+
},
86+
},
87+
},
88+
})
89+
require.NoError(t, err)
90+
91+
assert.Equal(t, map[string]any{
92+
"type": "object",
93+
"properties": map[string]any{
94+
"config": map[string]any{
95+
"type": "object",
96+
"properties": map[string]any{
97+
"host": map[string]any{
98+
"type": "string",
99+
},
100+
"metadata": map[string]any{
101+
"type": "object",
102+
"description": "nested metadata without type",
103+
},
104+
},
105+
},
106+
},
107+
}, m)
108+
}
109+
110+
func TestSchemaToMap_ArrayItemsPropertyWithoutType(t *testing.T) {
111+
m, err := SchemaToMap(map[string]any{
112+
"type": "object",
113+
"properties": map[string]any{
114+
"items": map[string]any{
115+
"type": "array",
116+
"items": map[string]any{
117+
"type": "object",
118+
"properties": map[string]any{
119+
"value": map[string]any{
120+
"description": "value without type",
121+
},
122+
},
123+
},
124+
},
125+
},
126+
})
127+
require.NoError(t, err)
128+
129+
assert.Equal(t, map[string]any{
130+
"type": "object",
131+
"properties": map[string]any{
132+
"items": map[string]any{
133+
"type": "array",
134+
"items": map[string]any{
135+
"type": "object",
136+
"properties": map[string]any{
137+
"value": map[string]any{
138+
"type": "object",
139+
"description": "value without type",
140+
},
141+
},
142+
},
143+
},
144+
},
145+
}, m)
146+
}
147+
148+
func TestSchemaToMap_DeeplyNestedPropertyWithoutType(t *testing.T) {
149+
m, err := SchemaToMap(map[string]any{
150+
"type": "object",
151+
"properties": map[string]any{
152+
"level1": map[string]any{
153+
"type": "object",
154+
"properties": map[string]any{
155+
"level2": map[string]any{
156+
"type": "object",
157+
"properties": map[string]any{
158+
"level3": map[string]any{
159+
"description": "deeply nested without type",
160+
},
161+
},
162+
},
163+
},
164+
},
165+
},
166+
})
167+
require.NoError(t, err)
168+
169+
assert.Equal(t, map[string]any{
170+
"type": "object",
171+
"properties": map[string]any{
172+
"level1": map[string]any{
173+
"type": "object",
174+
"properties": map[string]any{
175+
"level2": map[string]any{
176+
"type": "object",
177+
"properties": map[string]any{
178+
"level3": map[string]any{
179+
"type": "object",
180+
"description": "deeply nested without type",
181+
},
182+
},
183+
},
184+
},
185+
},
186+
},
187+
}, m)
188+
}

0 commit comments

Comments
 (0)