Files
compose/ai/tts-stt/app.py

129 lines
4.0 KiB
Python
Raw Normal View History

import os
import torch
import numpy as np
from TTS.api import TTS
from faster_whisper import WhisperModel
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
import tempfile
import uuid
from pydantic import BaseModel
from typing import Optional
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(
    title="TTS-STT Service",
    description="Coqui TTS and faster-whisper service",
)

# Model handles. Both start as None and are assigned by the startup hook;
# the request handlers treat None as "model not loaded".
tts_model = None
whisper_model = None
@app.on_event("startup")
async def startup_event():
    """Load the TTS and Whisper models into the module-level globals.

    Configuration is read from the environment:

    * ``TTS_MODEL``     - Coqui model name (default: YourTTS multilingual).
    * ``WHISPER_MODEL`` - faster-whisper model name (default: ``large-v3``).
    * ``DEVICE``        - target device; defaults to ``cuda`` when torch
      reports CUDA as available, otherwise ``cpu``.
    """
    global tts_model, whisper_model

    # Resolve configuration first so both log lines report the same device.
    tts_model_name = os.getenv(
        "TTS_MODEL", "tts_models/multilingual/multi-dataset/your_tts"
    )
    whisper_model_name = os.getenv("WHISPER_MODEL", "large-v3")
    device = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")

    # Only the exact string "cuda" enables GPU mode / float16 below.
    use_gpu = device == "cuda"
    if use_gpu:
        compute_type = "float16"
    else:
        compute_type = "int8"

    print(f"Loading TTS model: {tts_model_name} on {device}")
    tts_model = TTS(model_name=tts_model_name, progress_bar=False, gpu=use_gpu)

    print(f"Loading Whisper model: {whisper_model_name} on {device}")
    whisper_model = WhisperModel(
        whisper_model_name,
        device=device,
        compute_type=compute_type,
    )

    print("Models loaded successfully")
# Request body accepted by the POST /tts endpoint.
class TTSRequest(BaseModel):
    # Text to synthesize (required).
    text: str

    # Language code handed to the TTS model; the /tts handler substitutes
    # "en" when this is omitted or empty.
    language: Optional[str] = None

    # Path to a reference WAV used for voice cloning with YourTTS. The /tts
    # handler checks it with os.path.exists, i.e. it is resolved on the
    # server's filesystem.
    speaker_wav: Optional[str] = None
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` and return the result as WAV audio.

    The model writes to a scratch file, which is read back into memory and
    always deleted before the handler returns, including when synthesis or
    the read fails.

    Args:
        request: Text to speak, an optional language code (``"en"`` is used
            when absent) and an optional ``speaker_wav`` path for voice
            cloning. A ``speaker_wav`` that does not exist on the server is
            ignored and regular synthesis is used instead.

    Returns:
        A ``Response`` carrying the WAV bytes with media type ``audio/wav``.

    Raises:
        HTTPException: 503 if the TTS model has not been loaded; 500 with the
            underlying error message if synthesis or file handling fails.
    """
    if tts_model is None:
        raise HTTPException(status_code=503, detail="TTS model not loaded")

    filepath = None
    try:
        # mkstemp gives a unique, securely created path in the system temp
        # directory; the model reopens it by name, so close our descriptor.
        fd, filepath = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

        synth_kwargs = {
            "text": request.text,
            "language": request.language or "en",
            "file_path": filepath,
        }
        # NOTE(security): speaker_wav is a client-supplied server-side path.
        # It lets callers probe for file existence and feed arbitrary files
        # to the audio loader; restrict it to a known directory if this
        # endpoint is exposed to untrusted clients.
        if request.speaker_wav and os.path.exists(request.speaker_wav):
            # Voice cloning: add the reference recording.
            synth_kwargs["speaker_wav"] = request.speaker_wav

        tts_model.tts_to_file(**synth_kwargs)

        with open(filepath, "rb") as f:
            audio_data = f.read()

        return Response(content=audio_data, media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on success and failure alike so scratch files never pile up.
        if filepath is not None:
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass
@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...), language: Optional[str] = None):
    """Transcribe an uploaded audio file with the Whisper model.

    The upload is spooled to a scratch file, transcribed, and the scratch
    file is always deleted afterwards, including on failure.

    Args:
        file: Uploaded audio. Only a sanitised extension of its client-side
            filename is reused; the rest of the name never reaches the
            filesystem.
        language: Optional language code passed to the model; ``None`` is
            forwarded unchanged.

    Returns:
        A dict with ``text`` (the joined transcription), ``language`` and
        ``language_probability`` as reported by the model.

    Raises:
        HTTPException: 503 if the Whisper model has not been loaded; 500 with
            the underlying error message if saving or transcription fails.
    """
    if whisper_model is None:
        raise HTTPException(status_code=503, detail="Whisper model not loaded")

    # file.filename is client-controlled (and may be None). Keep just a plain
    # alphanumeric extension so path separators can never enter the temp path.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ""

    filepath = None
    try:
        # Save uploaded file temporarily under a unique, server-chosen name.
        fd, filepath = tempfile.mkstemp(suffix=suffix)
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await file.read())

        # Transcribe
        segments, info = whisper_model.transcribe(
            filepath,
            language=language,
            beam_size=5,
            vad_filter=True,
            vad_parameters=dict(min_silence_duration_ms=500),
        )

        # Consume the segments while the scratch file still exists (deletion
        # happens in ``finally`` below). Each piece is stripped so joining
        # with a space does not double whitespace at segment boundaries.
        pieces = (segment.text.strip() for segment in segments)
        transcription = " ".join(piece for piece in pieces if piece)

        return {
            "text": transcription.strip(),
            "language": info.language,
            "language_probability": info.language_probability,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up on every path so failed requests do not leak uploads.
        if filepath is not None:
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass
@app.get("/health")
async def health_check():
    """Report service liveness and whether each model handle is set."""
    tts_ready = tts_model is not None
    whisper_ready = whisper_model is not None
    # "status" is always "healthy"; the two flags tell callers which models
    # are actually available.
    return dict(
        status="healthy",
        tts_loaded=tts_ready,
        whisper_loaded=whisper_ready,
    )
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here so that merely importing this
    # module does not require it.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )