Skip to content

Commit 5c7217f

Browse files
feat: add CAMB AI tool integration for speech & audio
1 parent 13e144e commit 5c7217f

File tree

13 files changed

+1379
-0
lines changed

13 files changed

+1379
-0
lines changed

python/packages/autogen-ext/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ llama-cpp = [
4141
"llama-cpp-python>=0.3.8",
4242
]
4343

44+
camb = ["camb-sdk>=1.0.0"]
4445
graphrag = ["graphrag>=2.3.0"]
4546
chromadb = ["chromadb>=1.0.0"]
4647
mem0 = ["mem0ai>=0.1.98"]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from ._audio_separation import AudioSeparationArgs, CambAudioSeparationTool
2+
from ._config import CambToolConfig
3+
from ._text_to_sound import CambTextToSoundTool, TextToSoundArgs
4+
from ._toolkit import CambAIToolkit
5+
from ._transcription import CambTranscriptionTool, TranscriptionArgs
6+
from ._translated_tts import CambTranslatedTTSTool, TranslatedTTSArgs
7+
from ._translation import CambTranslationTool, TranslationArgs
8+
from ._tts import CambTTSTool, TTSArgs
9+
from ._voice_clone import CambVoiceCloneTool, VoiceCloneArgs
10+
from ._voice_list import CambVoiceListTool, VoiceListArgs
11+
12+
__all__ = [
13+
"AudioSeparationArgs",
14+
"CambAIToolkit",
15+
"CambAudioSeparationTool",
16+
"CambTextToSoundTool",
17+
"CambToolConfig",
18+
"CambTranscriptionTool",
19+
"CambTranslatedTTSTool",
20+
"CambTranslationTool",
21+
"CambTTSTool",
22+
"CambVoiceCloneTool",
23+
"CambVoiceListTool",
24+
"TextToSoundArgs",
25+
"TranscriptionArgs",
26+
"TranslatedTTSArgs",
27+
"TranslationArgs",
28+
"TTSArgs",
29+
"VoiceCloneArgs",
30+
"VoiceListArgs",
31+
]
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import json
2+
from typing import Optional
3+
4+
from autogen_core import CancellationToken
5+
from pydantic import BaseModel, Field, model_validator
6+
from typing_extensions import Self
7+
8+
from ._base import CambBaseTool
9+
from ._config import CambToolConfig
10+
11+
12+
class AudioSeparationArgs(BaseModel):
13+
"""Arguments for the CAMB.AI audio separation tool."""
14+
15+
audio_url: Optional[str] = Field(
16+
default=None,
17+
description="URL of the audio file to separate.",
18+
)
19+
audio_file_path: Optional[str] = Field(
20+
default=None,
21+
description="Local file path of the audio file to separate.",
22+
)
23+
24+
@model_validator(mode="after")
25+
def _validate_audio_source(self) -> "AudioSeparationArgs":
26+
if not self.audio_url and not self.audio_file_path:
27+
raise ValueError("Either audio_url or audio_file_path must be provided.")
28+
if self.audio_url and self.audio_file_path:
29+
raise ValueError("Only one of audio_url or audio_file_path should be provided.")
30+
return self
31+
32+
33+
class CambAudioSeparationTool(CambBaseTool[AudioSeparationArgs, str]):
34+
"""Audio separation tool using CAMB.AI.
35+
36+
Separates vocals from background audio using the CAMB.AI audio separation API.
37+
Returns a JSON string with vocals and background URLs.
38+
39+
.. note::
40+
This tool requires the :code:`camb` extra for the :code:`autogen-ext` package.
41+
42+
To install:
43+
44+
.. code-block:: bash
45+
46+
pip install -U "autogen-agentchat" "autogen-ext[camb]"
47+
48+
Example usage:
49+
50+
.. code-block:: python
51+
52+
import asyncio
53+
from autogen_core import CancellationToken
54+
from autogen_ext.tools.camb import CambAudioSeparationTool, AudioSeparationArgs
55+
56+
async def main():
57+
tool = CambAudioSeparationTool(api_key="your-api-key")
58+
result = await tool.run(
59+
AudioSeparationArgs(audio_file_path="/path/to/audio.mp3"),
60+
CancellationToken(),
61+
)
62+
print(f"Separation result: {result}")
63+
64+
asyncio.run(main())
65+
"""
66+
67+
component_provider_override = "autogen_ext.tools.camb.CambAudioSeparationTool"
68+
69+
def __init__(
70+
self,
71+
api_key: Optional[str] = None,
72+
base_url: Optional[str] = None,
73+
timeout: Optional[float] = None,
74+
max_poll_attempts: int = 60,
75+
poll_interval: float = 2.0,
76+
) -> None:
77+
super().__init__(
78+
args_type=AudioSeparationArgs,
79+
return_type=str,
80+
name="camb_audio_separation",
81+
description=(
82+
"Separate vocals from background audio using CAMB.AI. "
83+
"Returns JSON with vocals and background URLs."
84+
),
85+
api_key=api_key,
86+
base_url=base_url,
87+
timeout=timeout,
88+
max_poll_attempts=max_poll_attempts,
89+
poll_interval=poll_interval,
90+
)
91+
92+
async def run(self, args: AudioSeparationArgs, cancellation_token: CancellationToken) -> str:
93+
client = self._get_client()
94+
95+
kwargs: dict = {}
96+
if args.audio_url:
97+
kwargs["media_url"] = args.audio_url
98+
elif args.audio_file_path:
99+
kwargs["media_file"] = open(args.audio_file_path, "rb")
100+
101+
try:
102+
task = await client.audio_separation.create_audio_separation(**kwargs)
103+
finally:
104+
if "media_file" in kwargs:
105+
kwargs["media_file"].close()
106+
107+
task_id = task.task_id
108+
109+
status = await self._poll_task_status(
110+
client.audio_separation.get_audio_separation_status,
111+
task_id,
112+
)
113+
114+
run_id = status.run_id
115+
result = await client.audio_separation.get_audio_separation_run_info(run_id)
116+
117+
output = {
118+
"foreground_audio_url": getattr(result, "foreground_audio_url", None),
119+
"background_audio_url": getattr(result, "background_audio_url", None),
120+
}
121+
return json.dumps(output)
122+
123+
@classmethod
124+
def _from_config(cls, config: CambToolConfig) -> Self:
125+
return cls(
126+
api_key=config.api_key,
127+
base_url=config.base_url,
128+
timeout=config.timeout,
129+
max_poll_attempts=config.max_poll_attempts,
130+
poll_interval=config.poll_interval,
131+
)
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import asyncio
2+
import os
3+
import struct
4+
import tempfile
5+
from abc import abstractmethod
6+
from typing import Any, Generic, Optional, TypeVar
7+
8+
from autogen_core import CancellationToken, Component
9+
from autogen_core.tools import BaseTool
10+
from pydantic import BaseModel
11+
from typing_extensions import Self
12+
13+
from ._config import CambToolConfig
14+
15+
ArgsT = TypeVar("ArgsT", bound=BaseModel)
16+
ReturnT = TypeVar("ReturnT")
17+
18+
19+
class CambBaseTool(BaseTool[ArgsT, ReturnT], Component[CambToolConfig], Generic[ArgsT, ReturnT]):
20+
"""Abstract base class for CAMB.AI tools.
21+
22+
Manages the AsyncCambAI client lifecycle and provides shared utilities
23+
for polling async tasks, saving audio, and detecting audio formats.
24+
Uses the ``camb-sdk`` package with its native async client.
25+
"""
26+
27+
component_type = "tool"
28+
component_config_schema = CambToolConfig
29+
30+
def __init__(
31+
self,
32+
args_type: type[ArgsT],
33+
return_type: type[ReturnT],
34+
name: str,
35+
description: str,
36+
api_key: Optional[str] = None,
37+
base_url: Optional[str] = None,
38+
timeout: Optional[float] = None,
39+
max_poll_attempts: int = 60,
40+
poll_interval: float = 2.0,
41+
) -> None:
42+
super().__init__(
43+
args_type=args_type,
44+
return_type=return_type,
45+
name=name,
46+
description=description,
47+
)
48+
self._api_key = api_key
49+
self._base_url = base_url
50+
self._timeout = timeout
51+
self._max_poll_attempts = max_poll_attempts
52+
self._poll_interval = poll_interval
53+
self._client: Any = None
54+
55+
def _get_api_key(self) -> str:
56+
"""Resolve API key from parameter or environment variable."""
57+
key = self._api_key or os.environ.get("CAMB_API_KEY")
58+
if not key:
59+
raise ValueError(
60+
"CAMB.AI API key is required. Provide it via the api_key parameter "
61+
"or set the CAMB_API_KEY environment variable."
62+
)
63+
return key
64+
65+
def _get_client(self) -> Any:
66+
"""Get or create the AsyncCambAI client (lazy initialization)."""
67+
if self._client is None:
68+
from camb.client import AsyncCambAI
69+
70+
kwargs: dict[str, Any] = {"api_key": self._get_api_key()}
71+
if self._base_url:
72+
kwargs["base_url"] = self._base_url
73+
if self._timeout is not None:
74+
kwargs["timeout"] = self._timeout
75+
self._client = AsyncCambAI(**kwargs)
76+
return self._client
77+
78+
async def _poll_task_status(
79+
self,
80+
status_func: Any,
81+
task_id: str,
82+
) -> Any:
83+
"""Poll an async task until completion or failure.
84+
85+
Args:
86+
status_func: Async function to call for status checks (e.g. client.transcription.get_transcription_task_status).
87+
task_id: The task ID to poll.
88+
89+
Returns:
90+
The final status result when the task completes.
91+
92+
Raises:
93+
RuntimeError: If the task fails or times out.
94+
"""
95+
for _ in range(self._max_poll_attempts):
96+
result = await status_func(task_id)
97+
status = getattr(result, "status", None)
98+
if status is None and hasattr(result, "message"):
99+
status = getattr(result.message, "status", None)
100+
if status in ("SUCCESS", "complete", "completed"):
101+
return result
102+
if status in ("ERROR", "TIMEOUT", "PAYMENT_REQUIRED", "failed", "error"):
103+
reason = getattr(result, "exception_reason", "") or ""
104+
raise RuntimeError(f"CAMB.AI task failed with status: {status}. {reason}")
105+
await asyncio.sleep(self._poll_interval)
106+
raise RuntimeError(
107+
f"CAMB.AI task timed out after {self._max_poll_attempts * self._poll_interval}s"
108+
)
109+
110+
@staticmethod
111+
def _detect_audio_format(data: bytes) -> str:
112+
"""Detect audio format from raw bytes."""
113+
if data[:4] == b"RIFF":
114+
return "wav"
115+
if data[:3] == b"ID3" or data[:2] == b"\xff\xfb":
116+
return "mp3"
117+
if data[:4] == b"fLaC":
118+
return "flac"
119+
if data[:4] == b"OggS":
120+
return "ogg"
121+
return "wav"
122+
123+
@staticmethod
124+
def _add_wav_header(
125+
raw_data: bytes, sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16
126+
) -> bytes:
127+
"""Add a WAV header to raw PCM audio data."""
128+
data_size = len(raw_data)
129+
header = struct.pack(
130+
"<4sI4s4sIHHIIHH4sI",
131+
b"RIFF",
132+
36 + data_size,
133+
b"WAVE",
134+
b"fmt ",
135+
16,
136+
1, # PCM format
137+
channels,
138+
sample_rate,
139+
sample_rate * channels * bits_per_sample // 8,
140+
channels * bits_per_sample // 8,
141+
bits_per_sample,
142+
b"data",
143+
data_size,
144+
)
145+
return header + raw_data
146+
147+
@staticmethod
148+
def _save_audio(data: bytes, extension: str = "wav") -> str:
149+
"""Save audio data to a temporary file and return the file path."""
150+
with tempfile.NamedTemporaryFile(suffix=f".{extension}", delete=False) as f:
151+
f.write(data)
152+
return f.name
153+
154+
def _to_config(self) -> CambToolConfig:
155+
return CambToolConfig(
156+
api_key=self._api_key,
157+
base_url=self._base_url,
158+
timeout=self._timeout,
159+
max_poll_attempts=self._max_poll_attempts,
160+
poll_interval=self._poll_interval,
161+
)
162+
163+
@classmethod
164+
@abstractmethod
165+
def _from_config(cls, config: CambToolConfig) -> Self:
166+
...
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
class CambToolConfig(BaseModel):
7+
"""Configuration for CAMB.AI tools.
8+
9+
Args:
10+
api_key: CAMB.AI API key. If not provided, falls back to CAMB_API_KEY environment variable.
11+
base_url: Base URL for the CAMB.AI API.
12+
timeout: Request timeout in seconds.
13+
max_poll_attempts: Maximum number of polling attempts for async tasks.
14+
poll_interval: Interval between polling attempts in seconds.
15+
"""
16+
17+
api_key: Optional[str] = None
18+
base_url: Optional[str] = None
19+
timeout: Optional[float] = None
20+
max_poll_attempts: int = 60
21+
poll_interval: float = 2.0

0 commit comments

Comments
 (0)