# Vosk speech-to-text HTTP service (synchronous Flask API and async aiohttp API).
# Standard library
import asyncio
import difflib
import io
import json
import logging
import os
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing import Process, Queue, Pool, cpu_count

# Third-party
import aiohttp
import numpy as np
import soundfile as sf
from aiohttp import web
from flask import Flask, request, jsonify
from vosk import Model, KaldiRecognizer

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configuration for high-performance processing.
# Values can be overridden via environment variables; the worker count
# derives from the host instead of hard-coding a machine-specific number.
NUM_WORKERS = int(os.getenv('NUM_WORKERS', str(cpu_count())))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '32'))
MAX_CONCURRENT_PROCESSES = int(os.getenv('MAX_CONCURRENT_PROCESSES', '48'))

# Location of the Vosk model directory (expected to be mounted into the container).
MODEL_PATH = os.getenv('MODEL_PATH', '/app/model')

# Global model instance, populated by load_model() in the parent process.
# Worker processes load their own copy (see confirm_voice_process).
model = None
def load_model():
    """Load the shared Vosk model from MODEL_PATH into the module-global `model`.

    Returns:
        Model: the loaded Vosk model instance.

    Raises:
        RuntimeError: if the model directory is missing, or Vosk fails to load it.
    """
    global model
    # Use the module logger (not print) to match the rest of the file.
    logger.info("Checking for model at: %s", MODEL_PATH)
    if not os.path.exists(MODEL_PATH):
        logger.error("Model directory not found at %s", MODEL_PATH)
        raise RuntimeError(f"Vosk model not found at {MODEL_PATH}. Please download and mount a model.")

    logger.info("Model directory exists at %s", MODEL_PATH)
    logger.info("Contents: %s", os.listdir(MODEL_PATH))
    try:
        model = Model(MODEL_PATH)
    except Exception as e:
        logger.error("Error loading model: %s", e)
        # Chain the original cause for easier debugging.
        raise RuntimeError(f"Failed to load Vosk model: {e}") from e
    logger.info("Model loaded successfully!")
    return model
def similarity(a, b):
    """Return a 0.0-1.0 score for how closely two strings match."""
    matcher = difflib.SequenceMatcher(None, a, b)
    return matcher.ratio()
# Per-worker-process cache of the Vosk model.  Loading a model is very
# expensive, so each worker loads it once on first use instead of once
# per audio file (the original reloaded it for every task).
_worker_model = None


def confirm_voice_process(args):
    """Transcribe one audio clip and compare it against a reference text.

    Runs inside a worker process (kept at module top level so it is
    picklable by ProcessPoolExecutor).

    Args:
        args: tuple of (audio_bytes, reference_text, samplerate), where
            audio_bytes is an encoded audio file readable by soundfile.

    Returns:
        dict with keys 'transcription' (str), 'similarity' (float) and
        'confirmed' (bool, True when similarity > 0.2).  On any failure
        a zeroed-out result is returned instead of raising, so one bad
        file cannot abort a whole batch.
    """
    global _worker_model
    audio_bytes, reference_text, samplerate = args

    try:
        data, _ = sf.read(io.BytesIO(audio_bytes))
        if len(data.shape) > 1:
            # Down-mix multi-channel audio to the first channel only.
            data = data[:, 0]
        if data.dtype != np.int16:
            # Vosk expects 16-bit PCM; soundfile returns floats in [-1, 1].
            data = (data * 32767).astype(np.int16)

        # Load the model once per worker process and reuse it afterwards.
        if _worker_model is None:
            _worker_model = Model(MODEL_PATH)
        recognizer = KaldiRecognizer(_worker_model, samplerate)
        recognizer.AcceptWaveform(data.tobytes())
        result = recognizer.Result()
        text = json.loads(result).get('text', '')
        sim = similarity(text, reference_text)

        return {
            'transcription': text,
            'similarity': sim,
            'confirmed': sim > 0.2,
        }
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        return {
            'transcription': '',
            'similarity': 0.0,
            'confirmed': False,
        }
def process_batch_parallel(audio_files, references):
    """Transcribe/confirm a batch of audio clips across worker processes.

    Args:
        audio_files: list of encoded audio file contents (bytes).
        references: list of reference texts, parallel to audio_files.

    Returns:
        list of result dicts from confirm_voice_process, in input order.
    """
    # Probe each file's sample rate from the header only -- sf.info does
    # not decode the audio, unlike the full sf.read the workers perform.
    samplerates = []
    for audio_bytes in audio_files:
        info = sf.info(io.BytesIO(audio_bytes))
        samplerates.append(info.samplerate)

    # One (audio, reference, samplerate) tuple per task.
    process_args = list(zip(audio_files, references, samplerates))

    # Fan the work out across processes; executor.map preserves input order.
    with ProcessPoolExecutor(max_workers=MAX_CONCURRENT_PROCESSES) as executor:
        results = list(executor.map(confirm_voice_process, process_args))

    return results
# Flask app for backward compatibility
app = Flask(__name__)


@app.route('/', methods=['GET'])
def health_check():
    """Report liveness for the synchronous (Flask) API."""
    payload = {'status': 'ok', 'service': 'vosk-transcription-api', 'model': 'persian'}
    return jsonify(payload)
@app.route('/batch_confirm', methods=['POST'])
def batch_confirm():
    """Handle batch confirmation requests.

    Expects a form field 'references' (a JSON array of reference texts)
    plus one uploaded file per reference, named 'audio0', 'audio1', ...

    Returns:
        JSON {'results': [...]} on success, or {'error': ...} with HTTP 400.
    """
    start_time = time.time()

    # Parse and validate the references field.
    references = request.form.get('references')
    if not references:
        return jsonify({'error': 'Missing references'}), 400
    try:
        references = json.loads(references)
    except Exception:
        return jsonify({'error': 'Invalid references JSON'}), 400
    if not isinstance(references, list):
        # A JSON object or scalar would silently misbehave in the loop
        # below (len() of a dict/string), so reject it up front.
        return jsonify({'error': 'Invalid references JSON'}), 400

    # Collect the uploaded audio files, one per reference.
    audio_files = []
    for i in range(len(references)):
        audio_file = request.files.get(f'audio{i}')
        if not audio_file:
            return jsonify({'error': f'Missing audio file audio{i}'}), 400
        audio_files.append(audio_file.read())

    # Process the batch in parallel across worker processes.
    results = process_batch_parallel(audio_files, references)

    processing_time = time.time() - start_time
    logger.info(f"Processed batch of {len(results)} files in {processing_time:.2f}s")

    return jsonify({'results': results})
@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Handle single transcription request"""
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400

    audio_bytes = request.files['audio'].read()

    try:
        samples, samplerate = sf.read(io.BytesIO(audio_bytes))
        # Keep only the first channel of multi-channel audio.
        if len(samples.shape) > 1:
            samples = samples[:, 0]
        # Vosk expects 16-bit PCM samples.
        if samples.dtype != np.int16:
            samples = (samples * 32767).astype(np.int16)

        recognizer = KaldiRecognizer(model, samplerate)
        recognizer.AcceptWaveform(samples.tobytes())
        text = json.loads(recognizer.Result()).get('text', '')

        return jsonify({'transcription': text})
    except Exception as e:
        logger.error(f"Error in transcription: {e}")
        return jsonify({'error': str(e)}), 500
# Async version using aiohttp for better performance
async def async_batch_confirm(request):
    """Async version of batch confirmation.

    Expects multipart form data with a 'references' field (JSON array of
    reference texts) and one file per reference named 'audio0', 'audio1', ...

    Returns:
        aiohttp JSON response {'results': [...]} or {'error': ...} with 400.
    """
    start_time = time.time()

    # Parse multipart data
    data = await request.post()

    # Get and validate references.
    references_text = data.get('references')
    if not references_text:
        return web.json_response({'error': 'Missing references'}, status=400)

    try:
        references = json.loads(references_text)
    except Exception:
        return web.json_response({'error': 'Invalid references JSON'}, status=400)

    # Collect the uploaded audio files.
    audio_files = []
    for i in range(len(references)):
        field = data.get(f'audio{i}')
        if not field:
            return web.json_response({'error': f'Missing audio file audio{i}'}, status=400)

        # request.post() yields aiohttp FileField objects whose payload is a
        # plain synchronous file object; FileField has no awaitable read(),
        # so the previous `await field.read()` crashed at runtime.
        audio_files.append(field.file.read())

    # Run the CPU-heavy batch in a worker thread so the event loop stays free
    # (process_batch_parallel then fans out to a process pool itself).
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_PROCESSES) as executor:
        results = await loop.run_in_executor(
            executor,
            process_batch_parallel,
            audio_files,
            references,
        )

    processing_time = time.time() - start_time
    logger.info(f"Async processed batch of {len(results)} files in {processing_time:.2f}s")

    return web.json_response({'results': results})
async def async_transcribe(request):
    """Async version of single transcription.

    Expects multipart form data with one file field named 'audio'.

    Returns:
        aiohttp JSON response {'transcription': ...}, {'error': ...} with
        400 on missing input, or 500 on processing failure.
    """
    # Named `form` (not `data`) to avoid shadowing the audio sample array below.
    form = await request.post()

    if 'audio' not in form:
        return web.json_response({'error': 'No audio file provided'}, status=400)

    # aiohttp FileField wraps a synchronous file object; it has no awaitable
    # read(), so the previous `await audio_file.read()` crashed at runtime.
    audio_bytes = form['audio'].file.read()

    try:
        data, samplerate = sf.read(io.BytesIO(audio_bytes))
        if len(data.shape) > 1:
            # Keep only the first channel of multi-channel audio.
            data = data[:, 0]
        if data.dtype != np.int16:
            # Vosk expects 16-bit PCM samples.
            data = (data * 32767).astype(np.int16)

        recognizer = KaldiRecognizer(model, samplerate)
        recognizer.AcceptWaveform(data.tobytes())
        result = recognizer.Result()
        text = json.loads(result).get('text', '')

        return web.json_response({'transcription': text})
    except Exception as e:
        logger.error(f"Error in async transcription: {e}")
        return web.json_response({'error': str(e)}, status=500)
async def health_check_async(request):
    """Async health check"""
    payload = {
        'status': 'ok',
        'service': 'vosk-transcription-api-async',
        'model': 'persian',
        'workers': MAX_CONCURRENT_PROCESSES,
    }
    return web.json_response(payload)
def create_async_app():
    """Build and return the aiohttp application with all routes wired up."""
    async_app = web.Application()
    router = async_app.router

    # Health check plus the two transcription endpoints.
    router.add_get('/', health_check_async)
    router.add_post('/batch_confirm', async_batch_confirm)
    router.add_post('/transcribe', async_transcribe)

    return async_app
if __name__ == '__main__':
    # Load the shared Vosk model before serving any requests; failure here
    # aborts startup with a RuntimeError.
    load_model()

    # Choose between Flask and aiohttp based on environment.
    use_async = os.getenv('USE_ASYNC', 'false').lower() == 'true'

    if use_async:
        # Run the async (aiohttp) server.
        app = create_async_app()
        web.run_app(app, host='0.0.0.0', port=5000)
    else:
        # Run the Flask version.  Note: werkzeug rejects threaded=True
        # combined with processes > 1 ("cannot have a multithread and
        # multiprocess server"), so the original processes=4 argument made
        # this branch fail at startup; only threading is enabled.
        app.run(host='0.0.0.0', port=5000, threaded=True)