Skip to content

Commit 4adaf07

Browse files
author
Piotr Stachaczynski
committed
feat: include tool callback func
1 parent da88cfb commit 4adaf07

File tree

17 files changed

+76
-32
lines changed

17 files changed

+76
-32
lines changed

src/MaIN.Core.UnitTests/AgentContextTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public async Task ProcessAsync_WithStringMessage_ShouldReturnChatResult()
173173
.ReturnsAsync(chat);
174174

175175
_mockAgentService
176-
.Setup(s => s.Process(It.IsAny<Chat>(), _agentContext.GetAgentId(), It.IsAny<Knowledge>(), It.IsAny<bool>(), null))
176+
.Setup(s => s.Process(It.IsAny<Chat>(), _agentContext.GetAgentId(), It.IsAny<Knowledge>(), It.IsAny<bool>(), null, null))
177177
.ReturnsAsync(new Chat {
178178
Model = "test-model",
179179
Name = "test",

src/MaIN.Core.UnitTests/FlowContextTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public async Task ProcessAsync_WithStringMessage_ShouldReturnChatResult()
9898
.ReturnsAsync(chat);
9999

100100
_mockAgentService
101-
.Setup(s => s.Process(It.IsAny<Chat>(), firstAgent.Id, It.IsAny<Knowledge>(), It.IsAny<bool>(), null))
101+
.Setup(s => s.Process(It.IsAny<Chat>(), firstAgent.Id, It.IsAny<Knowledge>(), It.IsAny<bool>(), null, null))
102102
.ReturnsAsync(new Chat {
103103
Model = "test-model",
104104
Name = "test",

src/MaIN.Core/.nuspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<package>
33
<metadata>
44
<id>MaIN.NET</id>
5-
<version>0.7.7</version>
5+
<version>0.7.8</version>
66
<authors>Wisedev</authors>
77
<owners>Wisedev</owners>
88
<icon>favicon.png</icon>

src/MaIN.Core/Hub/Contexts/AgentContext.cs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,11 @@ public async Task<ChatResult> ProcessAsync(Chat chat, bool translate = false)
227227
};
228228
}
229229

230-
public async Task<ChatResult> ProcessAsync(string message, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
230+
public async Task<ChatResult> ProcessAsync(
231+
string message,
232+
bool translate = false,
233+
Func<LLMTokenValue, Task>? tokenCallback = null,
234+
Func<ToolInvocation, Task>? toolCallback = null)
231235
{
232236
if (_knowledge == null)
233237
{
@@ -241,7 +245,7 @@ public async Task<ChatResult> ProcessAsync(string message, bool translate = fals
241245
Type = MessageType.LocalLLM,
242246
Time = DateTime.Now
243247
});
244-
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
248+
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);
245249
var messageResult = result.Messages.LastOrDefault()!;
246250
return new ChatResult()
247251
{
@@ -252,15 +256,18 @@ public async Task<ChatResult> ProcessAsync(string message, bool translate = fals
252256
};
253257
}
254258

255-
public async Task<ChatResult> ProcessAsync(Message message, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
259+
public async Task<ChatResult> ProcessAsync(Message message,
260+
bool translate = false,
261+
Func<LLMTokenValue, Task>? tokenCallback = null,
262+
Func<ToolInvocation, Task>? toolCallback = null)
256263
{
257264
if (_knowledge == null)
258265
{
259266
LoadExistingKnowledgeIfExists();
260267
}
261268
var chat = await _agentService.GetChatByAgent(_agent.Id);
262269
chat.Messages.Add(message);
263-
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
270+
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);;
264271
var messageResult = result.Messages.LastOrDefault()!;
265272
return new ChatResult()
266273
{
@@ -271,7 +278,11 @@ public async Task<ChatResult> ProcessAsync(Message message, bool translate = fal
271278
};
272279
}
273280

274-
public async Task<ChatResult> ProcessAsync(IEnumerable<Message> messages, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
281+
public async Task<ChatResult> ProcessAsync(
282+
IEnumerable<Message> messages,
283+
bool translate = false,
284+
Func<LLMTokenValue, Task>? tokenCallback = null,
285+
Func<ToolInvocation, Task>? toolCallback = null)
275286
{
276287
if (_knowledge == null)
277288
{
@@ -285,7 +296,7 @@ public async Task<ChatResult> ProcessAsync(IEnumerable<Message> messages, bool t
285296
chat.Messages.Add(systemMsg);
286297
}
287298
chat.Messages.AddRange(messages);
288-
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
299+
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);;
289300
var messageResult = result.Messages.LastOrDefault()!;
290301
return new ChatResult()
291302
{

src/MaIN.Domain/Entities/Tools/ToolCall.cs

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace MaIN.Domain.Entities.Tools;
2+
3+
public class ToolInvocation
4+
{
5+
public string ToolName { get; set; }
6+
public string Arguments { get; set; } = null!;
7+
public bool Done { get; set; } = false;
8+
}

src/MaIN.Services/Services/Abstract/IAgentService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
using MaIN.Domain.Entities;
22
using MaIN.Domain.Entities.Agents;
33
using MaIN.Domain.Entities.Agents.Knowledge;
4+
using MaIN.Domain.Entities.Tools;
45
using MaIN.Domain.Models;
5-
using MaIN.Services.Services.LLMService;
66

77
namespace MaIN.Services.Services.Abstract;
88

99
public interface IAgentService
1010
{
1111
Task<Chat> Process(Chat chat, string agentId, Knowledge? knowledge, bool translatePrompt = false,
12-
Func<LLMTokenValue, Task>? callback = null);
12+
Func<LLMTokenValue, Task>? callbackToken = null, Func<ToolInvocation, Task>? callbackTool = null);
1313
Task<Agent> CreateAgent(Agent agent, bool flow = false, bool interactiveResponse = false,
1414
InferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false);
1515
Task<Chat> GetChatByAgent(string agentId);

src/MaIN.Services/Services/Abstract/IStepProcessor.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using MaIN.Domain.Entities;
22
using MaIN.Domain.Entities.Agents.Knowledge;
3+
using MaIN.Domain.Entities.Tools;
34
using MaIN.Domain.Models;
45
using MaIN.Infrastructure.Models;
56
using Microsoft.Extensions.Logging;
@@ -13,6 +14,7 @@ Task<Chat> ProcessSteps(AgentContextDocument context,
1314
Knowledge? knowledge,
1415
Chat chat,
1516
Func<LLMTokenValue, Task>? callback,
17+
Func<ToolInvocation, Task>? callbackTool,
1618
Func<string, string, string?, string, string, Task> notifyProgress,
1719
Func<Chat, Task> updateChat,
1820
ILogger logger);

src/MaIN.Services/Services/AgentService.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
using MaIN.Domain.Entities;
44
using MaIN.Domain.Entities.Agents;
55
using MaIN.Domain.Entities.Agents.Knowledge;
6+
using MaIN.Domain.Entities.Tools;
67
using MaIN.Domain.Models;
78
using MaIN.Infrastructure.Repositories.Abstract;
89
using MaIN.Services.Constants;
910
using MaIN.Services.Mappers;
1011
using MaIN.Services.Services.Abstract;
1112
using MaIN.Services.Services.ImageGenServices;
12-
using MaIN.Services.Services.LLMService;
1313
using MaIN.Services.Services.LLMService.Factory;
1414
using MaIN.Services.Services.Models.Commands;
1515
using MaIN.Services.Services.Steps.Commands;
@@ -35,7 +35,8 @@ public async Task<Chat> Process(
3535
string agentId,
3636
Knowledge? knowledge,
3737
bool translatePrompt = false,
38-
Func<LLMTokenValue, Task>? callback = null)
38+
Func<LLMTokenValue, Task>? callbackToken = null,
39+
Func<ToolInvocation, Task>? callbackTool = null)
3940
{
4041
var agent = await agentRepository.GetAgentById(agentId);
4142
if (agent == null)
@@ -53,7 +54,8 @@ await notificationService.DispatchNotification(
5354
agent,
5455
knowledge,
5556
chat,
56-
callback,
57+
callbackToken,
58+
callbackTool,
5759
async (status, id, progress, behaviour, details) =>
5860
{
5961
await notificationService.DispatchNotification(

src/MaIN.Services/Services/LLMService/AnthropicService.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using System.Text;
1111
using LLama.Common;
1212
using MaIN.Domain.Configuration;
13+
using MaIN.Domain.Entities.Tools;
1314
using MaIN.Services.Services.LLMService.Utils;
1415
using MaIN.Services.Services.LLMService;
1516

@@ -237,7 +238,19 @@ await notificationService.DispatchNotification(
237238
try
238239
{
239240
var inputJson = JsonSerializer.Serialize(toolUse.Input);
241+
options.ToolCallback?.Invoke(new ToolInvocation()
242+
{
243+
ToolName = toolUse.Name,
244+
Arguments = toolUse.Input.ToString() ?? string.Empty,
245+
Done = false
246+
});
240247
var toolResult = await executor(inputJson);
248+
options.ToolCallback?.Invoke(new ToolInvocation()
249+
{
250+
ToolName = toolUse.Name,
251+
Arguments = toolUse.Input.ToString() ?? string.Empty,
252+
Done = true
253+
});
241254

242255
toolResults.Add(new
243256
{

0 commit comments

Comments
 (0)