# Vosk-based Flask transcription / voice-confirmation service.
from flask import Flask, request, jsonify
|
|
from vosk import Model, KaldiRecognizer
|
|
import soundfile as sf
|
|
import io
|
|
import os
|
|
import json
|
|
import numpy as np
|
|
from multiprocessing import Process, Queue
|
|
import difflib
|
|
|
|
app = Flask(__name__)

# Path where the (Persian) Vosk model must be mounted inside the container.
MODEL_PATH = "/app/model"

# Load the model at import time and fail fast if it is missing or broken:
# the service is useless without it, so a RuntimeError here is deliberate.
print(f"Checking for model at: {MODEL_PATH}")
if not os.path.exists(MODEL_PATH):
    print(f"Model directory not found at {MODEL_PATH}")
    raise RuntimeError(f"Vosk model not found at {MODEL_PATH}. Please download and mount a model.")
print(f"Model directory exists at {MODEL_PATH}")
print(f"Contents: {os.listdir(MODEL_PATH)}")
try:
    model = Model(MODEL_PATH)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise RuntimeError(f"Failed to load Vosk model: {e}")
|
|
|
|
def similarity(a, b):
    """Return a similarity ratio in [0.0, 1.0] between two strings.

    Uses difflib's SequenceMatcher; 1.0 means identical sequences.
    """
    matcher = difflib.SequenceMatcher(a=a, b=b)
    return matcher.ratio()
|
|
|
|
def confirm_voice(audio_bytes, reference_text, samplerate, queue):
    """Worker: transcribe one audio clip and score it against a reference.

    Runs in a child process. On success puts
    {'transcription', 'similarity', 'confirmed'} on `queue`; on any failure
    puts {'error': ...} instead, so the parent process never blocks forever
    waiting on an empty queue.

    Args:
        audio_bytes: raw bytes of an audio file readable by soundfile.
        reference_text: expected transcription to compare against.
        samplerate: sample rate to configure the recognizer with.
        queue: multiprocessing.Queue used to return the result dict.
    """
    try:
        data, _ = sf.read(io.BytesIO(audio_bytes))
        if len(data.shape) > 1:
            data = data[:, 0]  # keep the first channel of multi-channel audio
        if data.dtype != np.int16:
            # soundfile returns floats in [-1, 1] by default; clip before
            # scaling so out-of-range samples cannot overflow int16.
            data = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
        recognizer = KaldiRecognizer(model, samplerate)
        recognizer.AcceptWaveform(data.tobytes())
        # FinalResult() flushes the recognizer; Result() alone can drop the
        # tail of the utterance when no endpoint was detected.
        text = json.loads(recognizer.FinalResult()).get('text', '')
        sim = similarity(text, reference_text)
        # 0.2 is a deliberately lenient threshold for noisy ASR output.
        queue.put({'transcription': text, 'similarity': sim, 'confirmed': sim > 0.2})
    except Exception as e:
        queue.put({'error': str(e)})
|
|
|
|
@app.route('/', methods=['GET'])
def health_check():
    """Liveness probe: report service identity and loaded model."""
    payload = {
        'status': 'ok',
        'service': 'vosk-transcription-api',
        'model': 'persian',
    }
    return jsonify(payload)
|
|
|
|
@app.route('/batch_confirm', methods=['POST'])
def batch_confirm():
    """Confirm a batch of audio clips against reference texts, in parallel.

    Expects multipart/form-data with:
      - 'references': a JSON-encoded list of reference strings
      - 'audio0', 'audio1', ...: one audio file per reference

    Returns JSON {'results': [...]} in reference order; each entry carries
    the transcription/similarity/confirmed fields (or an 'error' entry if
    that clip's worker failed). 400 on malformed input.
    """
    references = request.form.get('references')
    if not references:
        return jsonify({'error': 'Missing references'}), 400
    try:
        references = json.loads(references)
    except Exception:
        return jsonify({'error': 'Invalid references JSON'}), 400
    if not isinstance(references, list):
        # A JSON scalar/object would silently produce zero results below.
        return jsonify({'error': 'Invalid references JSON'}), 400

    audio_files = []
    for i in range(len(references)):
        audio_file = request.files.get(f'audio{i}')
        if not audio_file:
            return jsonify({'error': f'Missing audio file audio{i}'}), 400
        audio_files.append(audio_file.read())

    # Probe each clip's sample rate up front; a corrupt upload fails fast
    # here with a 400 instead of raising (a 500) or dying in a worker.
    samplerates = []
    for i, audio_bytes in enumerate(audio_files):
        try:
            _, samplerate = sf.read(io.BytesIO(audio_bytes))
        except Exception:
            return jsonify({'error': f'Invalid audio file audio{i}'}), 400
        samplerates.append(samplerate)

    # Fan out one process per clip so CPU-bound recognition runs in parallel.
    processes = []
    queues = []
    for audio_bytes, reference_text, samplerate in zip(audio_files, references, samplerates):
        queue = Queue()
        p = Process(target=confirm_voice, args=(audio_bytes, reference_text, samplerate, queue))
        processes.append(p)
        queues.append(queue)
        p.start()

    results = []
    for p, queue in zip(processes, queues):
        p.join()
        if p.exitcode != 0:
            # A crashed worker never put anything on its queue; checking the
            # exit code avoids blocking forever on queue.get().
            results.append({'error': 'Recognition worker failed'})
        else:
            results.append(queue.get())
    return jsonify({'results': results})
|
|
|
|
@app.route('/transcribe', methods=['POST'])
def transcribe():
    """Transcribe a single uploaded audio file.

    Expects multipart/form-data with an 'audio' file field readable by
    soundfile. Returns JSON {'transcription': <text>} on success, or a
    400 error for a missing or unreadable file.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400
    audio_file = request.files['audio']
    audio_bytes = audio_file.read()
    try:
        data, samplerate = sf.read(io.BytesIO(audio_bytes))
    except Exception:
        # A corrupt/unsupported upload is a client error, not a 500.
        return jsonify({'error': 'Invalid audio file'}), 400
    if len(data.shape) > 1:
        data = data[:, 0]  # Use first channel if stereo
    # Convert to 16-bit PCM, clipping so out-of-range floats can't overflow.
    if data.dtype != np.int16:
        data = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    recognizer = KaldiRecognizer(model, samplerate)
    recognizer.AcceptWaveform(data.tobytes())
    # FinalResult() flushes buffered audio; Result() alone can miss the
    # utterance tail when no endpoint was detected.
    result = recognizer.FinalResult()
    print(result)  # For debugging
    text = json.loads(result).get('text', '')
    return jsonify({'transcription': text})
|
|
|
|
if __name__ == '__main__':
    # Bind to all interfaces so the service is reachable from outside the
    # container; port 5000 is Flask's conventional default.
    app.run(host='0.0.0.0', port=5000)