-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
126 lines (102 loc) · 4.12 KB
/
main.py
File metadata and controls
126 lines (102 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import tempfile
from pathlib import Path
import whisper
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from voxcpm import VoxCPM
# ASGI application object; every endpoint below registers itself on it.
app = FastAPI(
    title="Speakify",
    description="Audio to text and text to audio conversion service",
)

# Model handles start out empty and are assigned by the startup hook.
tts_model = None
whisper_model = None
class TextInput(BaseModel):
    """Request body for the text-to-audio endpoint."""

    # The text to be synthesized into speech.
    text: str
class TranscriptionResponse(BaseModel):
    """Response body returned by the audio-to-text endpoint."""

    # The transcribed text, stripped of surrounding whitespace by the endpoint.
    text: str
@app.on_event("startup")
async def startup_event():
    """Populate the module-level model handles when the app starts.

    The Whisper "base" model is loaded unguarded, so a failure there
    propagates out of the startup hook. A VoxCPM failure is tolerated:
    a warning is printed and ``tts_model`` is left as ``None``.
    """
    global whisper_model, tts_model

    # Speech-to-text model.
    whisper_model = whisper.load_model("base")

    # Text-to-speech model; optional, so any loading error is only reported.
    try:
        tts_model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
        print("VoxCPM model loaded successfully")
    except Exception as load_error:
        print(f"Warning: Could not load VoxCPM model: {load_error}")
        tts_model = None
@app.get("/")
async def root():
    """Return a short banner identifying the API."""
    banner = "Speakify API - Audio to Text and Text to Audio"
    return {"message": banner}
@app.post("/audio-to-text", response_model=TranscriptionResponse)
async def audio_to_text(file: UploadFile = File(...)):
    """Convert an uploaded audio file to text using Whisper.

    The upload is spooled to a temporary file (Whisper's ``transcribe`` is
    given a filesystem path), transcribed, and the temporary file is removed
    again regardless of the outcome.

    Args:
        file: The uploaded audio file. Its declared content type must be one
            of the supported audio MIME types.

    Returns:
        A ``TranscriptionResponse`` holding the whitespace-stripped transcript.

    Raises:
        HTTPException: 500 if the Whisper model is not loaded, 400 if the
            content type is unsupported, 500 if saving or transcribing fails.
    """
    if not whisper_model:
        raise HTTPException(status_code=500, detail="Whisper model not loaded")

    # Validate file type. "audio/mpeg" is the registered MIME type for MP3
    # (RFC 3003); "audio/mp3" is kept for clients that send the unofficial one.
    allowed_types = [
        "audio/x-wav",
        "audio/wav",
        "audio/mpeg",
        "audio/mp3",
        "audio/m4a",
        "audio/ogg",
        "audio/flac",
    ]
    if file.content_type not in allowed_types:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file.content_type}. Supported types: {allowed_types}",
        )

    tmp_file_path = None
    try:
        # Keep the original extension on the temporary file.
        filename = file.filename or "temp_audio"
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(filename).suffix) as tmp_file:
            # Record the path before writing so a failed read/write still
            # gets cleaned up in the ``finally`` clause below.
            tmp_file_path = tmp_file.name
            tmp_file.write(await file.read())

        # Transcribe audio (the file is closed at this point).
        result = whisper_model.transcribe(tmp_file_path)
        return TranscriptionResponse(text=result["text"].strip())
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {e}") from e
    finally:
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not turn a
                # successful transcription into an error response.
                pass
def _remove_temp_file(path):
    """Delete *path*, ignoring OS errors (best-effort temp-file cleanup)."""
    try:
        os.unlink(path)
    except OSError:
        pass


@app.post("/text-to-audio")
async def text_to_audio(text_input: TextInput):
    """Convert text to audio using VoxCPM.

    The synthesized waveform is written to a temporary WAV file which is
    streamed back to the client and deleted once the response has been sent.

    Args:
        text_input: Request body carrying the text to synthesize.

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` and download name
        ``generated_audio.wav``.

    Raises:
        HTTPException: 500 if the TTS model is not loaded, 400 if the text is
            empty or whitespace-only, 500 if generation or writing fails.
    """
    # Local import: only needed here, to schedule post-response cleanup.
    from fastapi import BackgroundTasks

    if not tts_model:
        raise HTTPException(status_code=500, detail="VoxCPM TTS model not loaded")

    if not text_input.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    tmp_file_path = None
    try:
        # Generate audio using VoxCPM
        wav = tts_model.generate(
            text=text_input.text,
            prompt_wav_path="reference_speaker.wav",  # Voice to clone or None for no voice reference
            prompt_text="Hello, this is a test of the text-to-audio endpoint",  # reference text of recording
            cfg_value=2.0,  # Language model guidance
            inference_timesteps=10,  # Quality/speed tradeoff
        )

        # Reserve a temp path, then close the handle before soundfile reopens
        # the path (an open NamedTemporaryFile cannot be reopened on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file_path = tmp_file.name
        sf.write(tmp_file_path, wav, 16000)

        # The file must outlive this function so it can be streamed; remove it
        # after the response is sent instead of leaking one file per request.
        cleanup = BackgroundTasks()
        cleanup.add_task(_remove_temp_file, tmp_file_path)

        return FileResponse(
            tmp_file_path,
            media_type="audio/wav",
            filename="generated_audio.wav",
            background=cleanup,
        )
    except Exception as e:
        # No response will be sent, so the background cleanup never runs.
        if tmp_file_path is not None:
            _remove_temp_file(tmp_file_path)
        raise HTTPException(status_code=500, detail=f"Error generating audio: {e}") from e
@app.get("/health")
async def health_check():
    """Report a fixed healthy status plus whether each model handle is set."""
    whisper_ready = whisper_model is not None
    tts_ready = tts_model is not None
    return {
        "status": "healthy",
        "whisper_loaded": whisper_ready,
        "tts_loaded": tts_ready,
    }
# Script entry point: serve the app with uvicorn on all interfaces, port 9876.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(
        app,
        host="0.0.0.0",
        port=9876,
    )