129 lines
4.0 KiB
Python
129 lines
4.0 KiB
Python
|
|
import os
|
||
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
from TTS.api import TTS
|
||
|
|
from faster_whisper import WhisperModel
|
||
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
||
|
|
from fastapi.responses import Response
|
||
|
|
import tempfile
|
||
|
|
import uuid
|
||
|
|
from pydantic import BaseModel
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
app = FastAPI(
    title="TTS-STT Service",
    description="Coqui TTS and faster-whisper service",
)

# Model singletons. Both start as None and are assigned by the startup hook;
# request handlers check them before use.
tts_model: Optional[TTS] = None
whisper_model: Optional[WhisperModel] = None
|
||
|
|
|
||
|
|
@app.on_event("startup")
async def startup_event():
    """Load the TTS and Whisper models into the module-level globals.

    Configuration is read from environment variables:

    * ``TTS_MODEL``     - Coqui TTS model name
      (default ``tts_models/multilingual/multi-dataset/your_tts``).
    * ``WHISPER_MODEL`` - faster-whisper model name (default ``large-v3``).
    * ``DEVICE``        - defaults to ``"cuda"`` when ``torch`` reports a GPU,
      else ``"cpu"``. Only the exact value ``"cuda"`` enables the GPU path;
      any other value loads TTS with ``gpu=False`` and Whisper with ``int8``.
    """
    global tts_model, whisper_model

    tts_name = os.getenv("TTS_MODEL", "tts_models/multilingual/multi-dataset/your_tts")
    whisper_name = os.getenv("WHISPER_MODEL", "large-v3")
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = os.getenv("DEVICE", fallback_device)
    use_gpu = device == "cuda"

    print(f"Loading TTS model: {tts_name} on {device}")
    tts_model = TTS(model_name=tts_name, progress_bar=False, gpu=use_gpu)

    # Half precision on the GPU, 8-bit quantisation otherwise.
    precision = "float16" if use_gpu else "int8"
    print(f"Loading Whisper model: {whisper_name} on {device}")
    whisper_model = WhisperModel(whisper_name, device=device, compute_type=precision)

    print("Models loaded successfully")
|
||
|
|
|
||
|
|
class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /tts``."""

    # Text to synthesize.
    text: str
    # Optional language code; the /tts handler substitutes "en" when omitted.
    language: Optional[str] = None
    # Optional path, on the server's filesystem, to a reference WAV used for
    # voice cloning (YourTTS). Ignored by the handler if the path does not exist.
    speaker_wav: Optional[str] = None
|
||
|
|
|
||
|
|
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` and return it as a WAV response.

    If ``request.speaker_wav`` names a file that exists on this server, it is
    passed to the model as the reference voice (voice cloning). Otherwise the
    model is called without a speaker. ``request.language`` defaults to "en".

    NOTE(review): ``tts_to_file`` is a synchronous call inside an ``async``
    handler, so it blocks the event loop (including /health) while it runs.

    Returns:
        ``Response`` whose body is the generated WAV bytes (``audio/wav``).

    Raises:
        HTTPException: 503 if the TTS model is not loaded; 500 if synthesis or
            reading the output fails (``detail`` is the exception text).
    """
    if not tts_model:
        raise HTTPException(status_code=503, detail="TTS model not loaded")

    # Unique output path in the platform temp dir rather than a hard-coded /tmp.
    filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.wav")

    try:
        synth_kwargs = {
            "text": request.text,
            "language": request.language or "en",
            "file_path": filepath,
        }
        # SECURITY NOTE(review): speaker_wav is a client-supplied path that is
        # opened on the server. Any readable file can be used as the reference
        # voice, and the fallback lets callers probe for file existence.
        # Consider restricting it to an allow-listed directory.
        if request.speaker_wav and os.path.exists(request.speaker_wav):
            # Voice cloning from the reference clip.
            synth_kwargs["speaker_wav"] = request.speaker_wav
        tts_model.tts_to_file(**synth_kwargs)

        with open(filepath, "rb") as f:
            audio_data = f.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the temp file on success AND failure (it previously leaked
        # whenever synthesis raised). It may not exist if synthesis failed early.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass

    return Response(content=audio_data, media_type="audio/wav")
|
||
|
|
|
||
|
|
@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...), language: Optional[str] = None):
    """Transcribe an uploaded audio file with faster-whisper.

    Args:
        file: Uploaded audio. Only a short alphanumeric extension of its
            client-supplied name is reused, as the temp file's suffix. The name
            itself never becomes part of a server path.
        language: Optional language code forwarded to ``transcribe``. When
            ``None``, the model's own language detection is used.

    NOTE(review): transcription is synchronous inside an ``async`` handler, so
    it blocks the event loop while it runs.

    Returns:
        dict with ``text`` (segments joined by spaces, stripped), ``language``
        and ``language_probability`` taken from the transcription info.

    Raises:
        HTTPException: 503 if the Whisper model is not loaded; 500 if saving or
            transcribing fails (``detail`` is the exception text).
    """
    if not whisper_model:
        raise HTTPException(status_code=503, detail="Whisper model not loaded")

    # The client-supplied filename may be None or contain "/" / ".." parts, so
    # it must not be embedded in the path. Keep only a sane extension.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not (ext[1:].isalnum() and len(ext) <= 16):
        ext = ""
    filepath = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}{ext}")

    try:
        with open(filepath, "wb") as buffer:
            buffer.write(await file.read())

        segments, info = whisper_model.transcribe(
            filepath,
            language=language,
            beam_size=5,
            vad_filter=True,
            vad_parameters=dict(min_silence_duration_ms=500),
        )
        # faster-whisper yields segments lazily, so consume them here while
        # the temp file still exists.
        transcription = " ".join(segment.text for segment in segments)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up on success AND failure (the file previously leaked on error).
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass

    return {
        "text": transcription.strip(),
        "language": info.language,
        "language_probability": info.language_probability,
    }
|
||
|
|
|
||
|
|
@app.get("/health")
async def health_check():
    """Liveness probe reporting whether each model global has been populated.

    ``status`` is always "healthy". Model readiness is conveyed only by the
    two boolean flags.
    """
    readiness = {
        "tts_loaded": tts_model is not None,
        "whisper_loaded": whisper_model is not None,
    }
    return {"status": "healthy", **readiness}
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Direct invocation: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|