Add TTS/STT service with Coqui TTS and faster-whisper
- Added tts-stt service definition to docker-compose.yml - Created tts-stt directory with: - Dockerfile: Based on debian:trixie-slim with ROCm dependencies - requirements.txt: Python packages including TTS, faster-whisper, FastAPI - app.py: FastAPI service with /tts and /stt endpoints - Service includes GPU device mapping (/dev/kfd, /dev/dri) for ROCm acceleration - Uses YourTTS multilingual model for TTS and large-v3 for Whisper - Configured to use persistent storage for models and cache
This commit is contained in:
128
ai/tts-stt/app.py
Normal file
128
ai/tts-stt/app.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from TTS.api import TTS
|
||||
from faster_whisper import WhisperModel
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
||||
from fastapi.responses import Response
|
||||
import tempfile
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
app = FastAPI(title="TTS-STT Service", description="Coqui TTS and faster-whisper service")
|
||||
|
||||
# Initialize models
|
||||
tts_model = None
|
||||
whisper_model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global tts_model, whisper_model
|
||||
|
||||
# Get model names from environment or use defaults
|
||||
tts_model_name = os.getenv("TTS_MODEL", "tts_models/multilingual/multi-dataset/your_tts")
|
||||
whisper_model_name = os.getenv("WHISPER_MODEL", "large-v3")
|
||||
device = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
print(f"Loading TTS model: {tts_model_name} on {device}")
|
||||
tts_model = TTS(model_name=tts_model_name, progress_bar=False, gpu=device=="cuda")
|
||||
|
||||
print(f"Loading Whisper model: {whisper_model_name} on {device}")
|
||||
whisper_model = WhisperModel(whisper_model_name, device=device, compute_type="float16" if device=="cuda" else "int8")
|
||||
|
||||
print("Models loaded successfully")
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
language: Optional[str] = None
|
||||
speaker_wav: Optional[str] = None # For voice cloning with YourTTS
|
||||
|
||||
@app.post("/tts")
|
||||
async def text_to_speech(request: TTSRequest):
|
||||
if not tts_model:
|
||||
raise HTTPException(status_code=503, detail="TTS model not loaded")
|
||||
|
||||
try:
|
||||
# Generate unique filename
|
||||
filename = f"{uuid.uuid4()}.wav"
|
||||
filepath = f"/tmp/{filename}"
|
||||
|
||||
# Synthesize speech
|
||||
if request.speaker_wav and os.path.exists(request.speaker_wav):
|
||||
# Voice cloning
|
||||
tts_model.tts_to_file(
|
||||
text=request.text,
|
||||
speaker_wav=request.speaker_wav,
|
||||
language=request.language or "en",
|
||||
file_path=filepath
|
||||
)
|
||||
else:
|
||||
# Regular TTS
|
||||
tts_model.tts_to_file(
|
||||
text=request.text,
|
||||
language=request.language or "en",
|
||||
file_path=filepath
|
||||
)
|
||||
|
||||
# Read and return the audio file
|
||||
with open(filepath, "rb") as f:
|
||||
audio_data = f.read()
|
||||
|
||||
# Clean up
|
||||
os.remove(filepath)
|
||||
|
||||
return Response(content=audio_data, media_type="audio/wav")
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/stt")
|
||||
async def speech_to_text(file: UploadFile = File(...), language: Optional[str] = None):
|
||||
if not whisper_model:
|
||||
raise HTTPException(status_code=503, detail="Whisper model not loaded")
|
||||
|
||||
try:
|
||||
# Save uploaded file temporarily
|
||||
filename = f"{uuid.uuid4()}_{file.filename}"
|
||||
filepath = f"/tmp/{filename}"
|
||||
|
||||
with open(filepath, "wb") as buffer:
|
||||
content = await file.read()
|
||||
buffer.write(content)
|
||||
|
||||
# Transcribe
|
||||
segments, info = whisper_model.transcribe(
|
||||
filepath,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters=dict(min_silence_duration_ms=500)
|
||||
)
|
||||
|
||||
# Collect transcription
|
||||
transcription = " ".join([segment.text for segment in segments])
|
||||
|
||||
# Clean up
|
||||
os.remove(filepath)
|
||||
|
||||
return {
|
||||
"text": transcription.strip(),
|
||||
"language": info.language,
|
||||
"language_probability": info.language_probability
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"tts_loaded": tts_model is not None,
|
||||
"whisper_loaded": whisper_model is not None
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user