.gitignore (vendored, new file)
@@ -0,0 +1,25 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Vosk model files
vosk-model-fa-0.42/

# Archive files
*.zip

# Dataset files
confirmed_dataset/

# Data files
*.csv

# Kiro files
*.kiro

.python-version (new file)
@@ -0,0 +1 @@
3.10

another_copy_of_fine_tune_whisper(1)(1).py (new file)
@@ -0,0 +1,212 @@
from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset("Ashegh-Sad-Warrior/Persian_Common_Voice_17_0", split="validated[:20%]")
common_voice["test"] = load_dataset("Ashegh-Sad-Warrior/Persian_Common_Voice_17_0", split="validated[20%:23%]")

print(common_voice)

common_voice = common_voice.remove_columns(['client_id', 'path', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'])

print(common_voice)


from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="fa", task="transcribe")


from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="fa", task="transcribe")

from parsivar import Normalizer, Tokenizer, SpellCheck

import numpy as np
import tqdm

normalizer = Normalizer()
spell_checker = SpellCheck()
vocab = np.array([])
for i in tqdm.tqdm(common_voice["train"]["sentence"]):
    i = spell_checker.spell_corrector(normalizer.normalize(i))
    vocab = np.append(vocab, Tokenizer().tokenize_words(i), axis=0)
for i in tqdm.tqdm(common_voice["test"]["sentence"]):
    i = spell_checker.spell_corrector(normalizer.normalize(i))
    vocab = np.append(vocab, Tokenizer().tokenize_words(i), axis=0)
vocab = np.unique(vocab)
print(vocab, vocab.shape)


import numpy as np
import tqdm

normalizer = Normalizer()
vocab = np.array([])
for i in tqdm.tqdm(common_voice["train"]["sentence"]):
    i = normalizer.normalize(i)
    vocab = np.append(vocab, Tokenizer().tokenize_words(i), axis=0)
for i in tqdm.tqdm(common_voice["test"]["sentence"]):
    i = normalizer.normalize(i)
    vocab = np.append(vocab, Tokenizer().tokenize_words(i), axis=0)
vocab = np.unique(vocab)
print(vocab, vocab.shape)

tokenizer.add_tokens(list(vocab))

processor.tokenizer = tokenizer


print(common_voice["train"][0])


from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))


print(common_voice["train"][0])


def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch


common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"])


from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

model.generation_config.language = "fa"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None


import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


import evaluate

metric = evaluate.load("wer")


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-tiny-fa",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-6,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=448,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)


from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)


model.resize_token_embeddings(len(processor.tokenizer))

trainer.train()

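For reference, a minimal inference sketch for the checkpoint this script produces (not part of the commit; it assumes the processor/tokenizer files end up alongside the model in `./whisper-tiny-fa`, and `sample.wav` is a placeholder path):

```python
# Hypothetical post-training check: load the fine-tuned checkpoint and transcribe one file.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="./whisper-tiny-fa",  # output_dir from Seq2SeqTrainingArguments above
    generate_kwargs={"language": "fa", "task": "transcribe"},
)
print(asr("sample.wav")["text"])
```
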
docker-compose.yml (new file)
@@ -0,0 +1,11 @@
version: '3.8'
services:
  vosk:
    build: ./vosk_service
    container_name: vosk-api
    ports:
      - "5000:5000"
    # Uncomment the next lines if you want to mount the model from the host instead of copying into the image
    # volumes:
    #   - ./vosk_service/model:/app/model
    restart: unless-stopped

main.py (new file)
@@ -0,0 +1,6 @@
def main():
    print("Hello from whisper!")


if __name__ == "__main__":
    main()

pyproject.toml (new file)
@@ -0,0 +1,30 @@
[project]
name = "whisper"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "brotli>=1.1.0",
    "datasets>=4.0.0",
    "librosa>=0.11.0",
    "numpy>=2.2.6",
    "pandas>=2.3.1",
    "pyqt5>=5.15.11",
    "requests>=2.32.4",
    "sounddevice>=0.5.2",
    "soundfile>=0.13.1",
    "tk>=0.1.0",
    "torch>=2.7.1",
    "torchcodec>=0.4.0",
    "tqdm>=4.67.1",
    "transformers>=4.37.2",
    "accelerate==0.21.0",
    "faster-whisper>=0.10.0",
    "sentencepiece>=0.2.0",
    "evaluate>=0.4.5",
    "jiwer>=4.0.0",
    "tensorboard>=2.19.0",
    "hazm>=0.7.0",
    "pyyaml>=6.0.0",
]

vosk/test_files/batch_confirm_hf.py (new file)
@@ -0,0 +1,190 @@
from datasets import load_dataset, Audio, Dataset
import soundfile as sf
import requests
import os
from tqdm import tqdm
import pandas as pd
import json
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
from huggingface_hub import HfApi, create_repo

# Load the dataset with audio decoding
print("Loading dataset...")
ds = load_dataset(
    "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
    split="validated[:500]",
    streaming=False
).cast_column("audio", Audio(sampling_rate=16000))

output_dir = "confirmed_dataset"
os.makedirs(output_dir, exist_ok=True)
confirmed = []

API_URL = "http://localhost:5000/batch_confirm"
batch_size = 8

# Hugging Face configuration
HF_DATASET_NAME = "dpr2000/persian-cv17-confirmed"  # Change this to your desired dataset name
HF_PRIVATE = True  # Set to True for private dataset


def save_flac(audio_array, path):
    sf.write(path, audio_array, 16000, format="FLAC")


print("Processing batches...")
for i in tqdm(range(0, len(ds), batch_size)):
    batch = ds[i:i+batch_size]
    files = {}
    references = []
    temp_flacs = []
    audio_arrays = []
    # Fix: batch is a dict of lists
    for j in range(len(batch["audio"])):
        audio = batch["audio"][j]
        flac_path = f"temp_{i+j}.flac"
        save_flac(audio["array"], flac_path)
        files[f"audio{j}"] = open(flac_path, "rb")
        references.append(batch["sentence"][j])
        temp_flacs.append(flac_path)
        audio_arrays.append(audio["array"])  # Store the array for confirmed
    data = {"references": json.dumps(references)}
    try:
        response = requests.post(API_URL, files=files, data=data, timeout=120)
        if response.status_code == 200:
            resp_json = response.json()
            if "results" in resp_json:
                results = resp_json["results"]
            else:
                print(f"Batch {i} failed: 'results' key missing in response: {resp_json}")
                results = [None] * len(references)
        else:
            print(f"Batch {i} failed: HTTP {response.status_code} - {response.text}")
            results = [None] * len(references)
    except Exception as e:
        print(f"Batch {i} failed: {e}")
        results = [None] * len(references)
    for j, result in enumerate(results):
        if result and result.get("confirmed"):
            # Save confirmed audio array and transcription
            confirmed.append({"audio": audio_arrays[j], "transcription": references[j]})
            os.remove(temp_flacs[j])
        else:
            os.remove(temp_flacs[j])
    for f in files.values():
        f.close()

# Save confirmed data using sharding approach
if confirmed:
    print(f"\n🔄 Saving {len(confirmed)} confirmed samples...")

    # Convert confirmed data to HuggingFace dataset format
    def extract_minimal(example):
        # Convert float32 audio (range -1.0 to 1.0) to int16 (range -32768 to 32767)
        audio_float32 = np.array(example["audio"], dtype=np.float32)
        # Ensure audio is in valid range and scale to int16
        audio_float32 = np.clip(audio_float32, -1.0, 1.0)
        audio_int16 = (audio_float32 * 32767).astype(np.int16)
        return {
            "audio": audio_int16.tobytes(),  # Store as int16 bytes, compatible with Whisper
            "text": example["transcription"]
        }

    # Create dataset from confirmed samples
    confirmed_dataset = Dataset.from_list(confirmed)
    confirmed_dataset = confirmed_dataset.map(extract_minimal, remove_columns=confirmed_dataset.column_names)

    # Sharding parameters
    num_shards = min(1, len(confirmed))  # Don't create more shards than samples
    shard_size = len(confirmed_dataset) // num_shards + 1

    # Write each shard separately
    for i in range(num_shards):
        start = i * shard_size
        end = min(len(confirmed_dataset), (i + 1) * shard_size)

        if start >= len(confirmed_dataset):
            break

        shard = confirmed_dataset.select(range(start, end))
        table = pa.Table.from_pandas(shard.to_pandas())  # Convert to PyArrow table

        shard_path = os.path.join(output_dir, f"confirmed_shard_{i:02}.parquet")

        pq.write_table(
            table,
            shard_path,
            compression="zstd",
            compression_level=22,  # Maximum compression
            use_dictionary=True,
            version="2.6"
        )

        print(f"🔹 Shard {i+1}/{num_shards}: {len(shard)} samples saved")

    print(f"\n✅ All confirmed data saved in {num_shards} shards in `{output_dir}/`")

    # Push to Hugging Face Hub
    print(f"\n🚀 Pushing dataset to Hugging Face Hub as '{HF_DATASET_NAME}'...")
    try:
        # Initialize HF API
        api = HfApi()

        # Create the repository (private if specified)
        try:
            create_repo(
                repo_id=HF_DATASET_NAME,
                repo_type="dataset",
                private=HF_PRIVATE,
                exist_ok=True
            )
            print(f"✅ Repository '{HF_DATASET_NAME}' created/verified")
        except Exception as e:
            print(f"⚠️ Repository creation: {e}")

        # Upload all parquet files
        for i in range(num_shards):
            shard_path = os.path.join(output_dir, f"confirmed_shard_{i:02}.parquet")
            if os.path.exists(shard_path):
                api.upload_file(
                    path_or_fileobj=shard_path,
                    path_in_repo=f"confirmed_shard_{i:02}.parquet",
                    repo_id=HF_DATASET_NAME,
                    repo_type="dataset"
                )
                print(f"📤 Uploaded shard {i+1}/{num_shards}")

        # Create dataset info file
        dataset_info = {
            "dataset_name": HF_DATASET_NAME,
            "description": "Persian Common Voice confirmed samples for Whisper fine-tuning",
            "total_samples": len(confirmed),
            "num_shards": num_shards,
            "audio_format": "int16 PCM, 16kHz",
            "columns": ["audio", "text"],
            "source_dataset": "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
            "processing": "Vosk API batch confirmation"
        }

        # Upload dataset info
        import tempfile
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(dataset_info, f, indent=2, ensure_ascii=False)
            info_path = f.name

        api.upload_file(
            path_or_fileobj=info_path,
            path_in_repo="dataset_info.json",
            repo_id=HF_DATASET_NAME,
            repo_type="dataset"
        )
        os.unlink(info_path)

        print(f"🎉 Dataset successfully pushed to: https://huggingface.co/datasets/{HF_DATASET_NAME}")

    except Exception as e:
        print(f"❌ Failed to push to Hugging Face: {e}")
        print("💡 Make sure you're logged in with: huggingface-cli login")

else:
    print("❌ No confirmed samples to save")

vosk/test_files/debug_batch_confirm.py (new file)
@@ -0,0 +1,52 @@
import requests
import json
import soundfile as sf
import numpy as np
import os

# Test the API connection
API_URL = "http://localhost:5000/batch_confirm"


def test_api():
    print("Testing API connection...")
    try:
        response = requests.get("http://localhost:5000/")
        print(f"API health check: {response.status_code}")
        print(f"Response: {response.json()}")
    except Exception as e:
        print(f"API not reachable: {e}")
        return False
    return True


def test_batch_confirm():
    print("\nTesting batch confirm...")

    # Create a simple test audio file
    test_audio = np.random.randn(16000).astype(np.float32)  # 1 second of noise
    test_path = "test_audio.flac"
    sf.write(test_path, test_audio, 16000, format="FLAC")

    # Test batch confirm
    with open(test_path, "rb") as f:
        files = {"audio0": f}
        data = {"references": json.dumps(["test sentence"])}

        try:
            response = requests.post(API_URL, files=files, data=data, timeout=30)
            print(f"Batch confirm response: {response.status_code}")
            if response.status_code == 200:
                print(f"Response JSON: {response.json()}")
            else:
                print(f"Error: {response.text}")
        except Exception as e:
            print(f"Batch confirm failed: {e}")

    # Clean up
    if os.path.exists(test_path):
        os.remove(test_path)


if __name__ == "__main__":
    if test_api():
        test_batch_confirm()
    else:
        print("Please start the Vosk API first!")

vosk/test_files/download_large_persian_model.py (new file)
@@ -0,0 +1,33 @@
import os
import requests
import zipfile

MODEL_URL = "https://alphacephei.com/vosk/models/vosk-model-fa-0.42.zip"
MODEL_ZIP = "vosk-model-fa-0.42.zip"
MODEL_DIR = "vosk-model-fa-0.42"

# Download the model zip if not present
if not os.path.exists(MODEL_ZIP):
    print(f"Downloading {MODEL_URL} ...")
    with requests.get(MODEL_URL, stream=True) as r:
        r.raise_for_status()
        total = int(r.headers.get('content-length', 0))
        with open(MODEL_ZIP, 'wb') as f:
            downloaded = 0
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    print(f"\rDownloaded {downloaded/1024/1024:.2f} MB / {total/1024/1024:.2f} MB", end='', flush=True)
    print("\nDownload complete.")
else:
    print(f"{MODEL_ZIP} already exists.")

# Extract the model zip if not already extracted
if not os.path.exists(MODEL_DIR):
    print(f"Extracting {MODEL_ZIP} ...")
    with zipfile.ZipFile(MODEL_ZIP, 'r') as zip_ref:
        zip_ref.extractall()
    print(f"Extracted to {MODEL_DIR}.")
else:
    print(f"{MODEL_DIR} already extracted.")

vosk/test_files/human_confirm_parquet.py (new file)
@@ -0,0 +1,89 @@
import sys
import os
import pandas as pd
import numpy as np
import sounddevice as sd
from PyQt5.QtWidgets import (
    QApplication, QWidget, QLabel, QPushButton, QVBoxLayout, QHBoxLayout, QMessageBox
)

parquet_path = os.path.join('confirmed_dataset', 'confirmed_shard_00.parquet')
df = pd.read_parquet(parquet_path)
results = []


class AudioReviewer(QWidget):
    def __init__(self, df):
        super().__init__()
        self.df = df
        self.idx = 0
        self.total = len(df)
        self.audio = None
        self.transcription = None

        self.setWindowTitle("Human Audio Confirmation GUI (PyQt5)")
        self.setGeometry(100, 100, 600, 200)

        self.label = QLabel(f"Sample 1/{self.total}", self)
        self.trans_label = QLabel("", self)
        self.play_button = QPushButton("Play Audio", self)
        self.yes_button = QPushButton("Yes (Correct)", self)
        self.no_button = QPushButton("No (Incorrect)", self)
        self.skip_button = QPushButton("Skip", self)
        self.quit_button = QPushButton("Quit", self)

        self.play_button.clicked.connect(self.play_audio)
        self.yes_button.clicked.connect(lambda: self.save_and_next('y'))
        self.no_button.clicked.connect(lambda: self.save_and_next('n'))
        self.skip_button.clicked.connect(lambda: self.save_and_next('skip'))
        self.quit_button.clicked.connect(self.quit)

        vbox = QVBoxLayout()
        vbox.addWidget(self.label)
        vbox.addWidget(self.trans_label)
        vbox.addWidget(self.play_button)

        hbox = QHBoxLayout()
        hbox.addWidget(self.yes_button)
        hbox.addWidget(self.no_button)
        hbox.addWidget(self.skip_button)
        hbox.addWidget(self.quit_button)
        vbox.addLayout(hbox)

        self.setLayout(vbox)
        self.load_sample()

    def load_sample(self):
        if self.idx >= self.total:
            QMessageBox.information(self, "Done", "All samples reviewed!")
            self.quit()
            return
        row = self.df.iloc[self.idx]
        # Convert bytes back to numpy array
        audio_bytes = row['audio']
        self.audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0  # Convert int16 to float32
        self.transcription = row['text']  # Use 'text' column instead of 'transcription'
        self.label.setText(f"Sample {self.idx+1}/{self.total}")
        self.trans_label.setText(f"Transcription: {self.transcription}")

    def play_audio(self):
        sd.play(self.audio, 16000)
        sd.wait()

    def save_and_next(self, result):
        results.append({
            'index': self.idx,
            'transcription': self.transcription,
            'result': result
        })
        self.idx += 1
        self.load_sample()

    def quit(self):
        pd.DataFrame(results).to_csv('human_confirmed_results.csv', index=False)
        self.close()


if __name__ == "__main__":
    app = QApplication(sys.argv)
    reviewer = AudioReviewer(df)
    reviewer.show()
    sys.exit(app.exec_())

vosk/test_files/test_vosk_transcription.py (new file)
@@ -0,0 +1,39 @@
import requests
import difflib
import sys

# Usage: python test_vosk_transcription.py <audio_file> <reference_text>

API_URL = 'http://localhost:5000/transcribe'


def similarity(a, b):
    return difflib.SequenceMatcher(None, a, b).ratio()


def main():
    if len(sys.argv) != 3:
        print("Usage: python test_vosk_transcription.py <audio_file> <reference_text>")
        sys.exit(1)
    audio_path = sys.argv[1]
    reference_text = sys.argv[2]
    with open(audio_path, 'rb') as f:
        files = {'audio': f}
        response = requests.post(API_URL, files=files)
    if response.status_code != 200:
        print(f"API error: {response.text}")
        sys.exit(1)
    transcription = response.json().get('transcription', '')
    sim = similarity(transcription, reference_text)
    print(f"Transcription: {transcription}")
    print(f"Reference: {reference_text}")
    print(f"Similarity: {sim:.2f}")
    if sim > 0.2:
        print("Test PASSED: Similarity above threshold.")
        sys.exit(0)
    else:
        print("Test FAILED: Similarity below threshold.")
        sys.exit(1)


if __name__ == '__main__':
    main()

vosk/vosk_service/Dockerfile (new file)
@@ -0,0 +1,22 @@
FROM python:3.10-slim

# Install dependencies
RUN apt-get update && apt-get install -y \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Copy service code
COPY app.py ./

# Copy model directory
COPY model/ ./model/

EXPOSE 5000

CMD ["python", "app.py"]

vosk/vosk_service/README.md (new file)
@@ -0,0 +1,26 @@
# Vosk Speech-to-Text Docker Service

## Setup

1. Download and extract a Vosk model and place it in `model/`. The rest of this repo uses the Persian model `vosk-model-fa-0.42.zip` (see `download_large_persian_model.py`):

```sh
unzip vosk-model-fa-0.42.zip
mv vosk-model-fa-0.42 model
```

2. Build the Docker image:

```sh
docker build -t vosk-api .
```

3. Run the Docker container (mounting the model directory):

```sh
docker run -p 5000:5000 -v $(pwd)/model:/app/model vosk-api
```

## API Usage

POST `/transcribe` with form-data key `audio` (WAV/FLAC/OGG file). Returns JSON with `transcription`.

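A minimal Python client for this endpoint, mirroring `test_vosk_transcription.py` above (`sample.flac` is a placeholder file name):

```python
import requests

# POST a local audio file to the /transcribe endpoint and print the result.
with open("sample.flac", "rb") as f:
    resp = requests.post("http://localhost:5000/transcribe", files={"audio": f})
resp.raise_for_status()
print(resp.json()["transcription"])
```
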
vosk/vosk_service/app.py (new file)
@@ -0,0 +1,108 @@
from flask import Flask, request, jsonify
from vosk import Model, KaldiRecognizer
import soundfile as sf
import io
import os
import json
import numpy as np
from multiprocessing import Process, Queue
import difflib

app = Flask(__name__)

MODEL_PATH = "/app/model"

# Check if model exists and load it
print(f"Checking for model at: {MODEL_PATH}")
if os.path.exists(MODEL_PATH):
    print(f"Model directory exists at {MODEL_PATH}")
    print(f"Contents: {os.listdir(MODEL_PATH)}")
    try:
        model = Model(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise RuntimeError(f"Failed to load Vosk model: {e}")
else:
    print(f"Model directory not found at {MODEL_PATH}")
    raise RuntimeError(f"Vosk model not found at {MODEL_PATH}. Please download and mount a model.")


def similarity(a, b):
    return difflib.SequenceMatcher(None, a, b).ratio()


def confirm_voice(audio_bytes, reference_text, samplerate, queue):
    data, _ = sf.read(io.BytesIO(audio_bytes))
    if len(data.shape) > 1:
        data = data[:, 0]
    if data.dtype != np.int16:
        data = (data * 32767).astype(np.int16)
    recognizer = KaldiRecognizer(model, samplerate)
    recognizer.AcceptWaveform(data.tobytes())
    result = recognizer.Result()
    text = json.loads(result).get('text', '')
    sim = similarity(text, reference_text)
    queue.put({'transcription': text, 'similarity': sim, 'confirmed': sim > 0.2})


@app.route('/', methods=['GET'])
def health_check():
    return jsonify({'status': 'ok', 'service': 'vosk-transcription-api', 'model': 'persian'})


@app.route('/batch_confirm', methods=['POST'])
def batch_confirm():
    # Expecting multipart/form-data with multiple audio files and a JSON list of references
    # audio files: audio0, audio1, ...
    # references: JSON list in 'references' field
    references = request.form.get('references')
    if not references:
        return jsonify({'error': 'Missing references'}), 400
    try:
        references = json.loads(references)
    except Exception:
        return jsonify({'error': 'Invalid references JSON'}), 400
    audio_files = []
    for i in range(len(references)):
        audio_file = request.files.get(f'audio{i}')
        if not audio_file:
            return jsonify({'error': f'Missing audio file audio{i}'}), 400
        audio_files.append(audio_file.read())
    results = []
    processes = []
    queues = []
    # Get sample rates for each audio
    samplerates = []
    for audio_bytes in audio_files:
        data, samplerate = sf.read(io.BytesIO(audio_bytes))
        samplerates.append(samplerate)
    for idx, (audio_bytes, reference_text, samplerate) in enumerate(zip(audio_files, references, samplerates)):
        queue = Queue()
        p = Process(target=confirm_voice, args=(audio_bytes, reference_text, samplerate, queue))
        processes.append(p)
        queues.append(queue)
        p.start()
    for p in processes:
        p.join()
    for queue in queues:
        results.append(queue.get())
    return jsonify({'results': results})


@app.route('/transcribe', methods=['POST'])
def transcribe():
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400
    audio_file = request.files['audio']
    audio_bytes = audio_file.read()
    data, samplerate = sf.read(io.BytesIO(audio_bytes))
    if len(data.shape) > 1:
        data = data[:, 0]  # Use first channel if stereo
    # Convert to 16-bit PCM
    if data.dtype != np.int16:
        data = (data * 32767).astype(np.int16)
    recognizer = KaldiRecognizer(model, samplerate)
    recognizer.AcceptWaveform(data.tobytes())
    result = recognizer.Result()
    print(result)  # For debugging
    text = json.loads(result).get('text', '')
    return jsonify({'transcription': text})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)

vosk/vosk_service/requirements.txt (new file)
@@ -0,0 +1,3 @@
vosk
Flask
soundfile

whisper-medium-finetuned/config/training_config.yaml (new file)
@@ -0,0 +1,31 @@
model:
  name: "openai/whisper-medium"
  language: "persian"
  task: "transcribe"

training:
  batch_size: 16
  learning_rate: 1e-5
  num_epochs: 10
  warmup_steps: 500
  gradient_accumulation_steps: 2
  save_steps: 1000
  eval_steps: 500
  logging_steps: 100

data:
  dataset_path: "confirmed_dataset/confirmed.parquet"
  max_audio_length: 30.0
  min_audio_length: 1.0
  train_split: 0.9
  eval_split: 0.1

persian:
  use_hazm: true
  normalize_text: true
  remove_diacritics: true

output:
  model_dir: "whisper-medium-finetuned/models"
  logs_dir: "whisper-medium-finetuned/logs"
  faster_whisper_dir: "whisper-medium-finetuned/faster_whisper"

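The commit does not include code that reads this config; a minimal sketch of how it could be loaded with PyYAML (already a dependency in `pyproject.toml`). `load_training_config` is a hypothetical helper name:

```python
import yaml

def load_training_config(path="whisper-medium-finetuned/config/training_config.yaml"):
    # Read the training config shown above into a plain dict.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

cfg = load_training_config()
print(cfg["model"]["name"], cfg["training"]["batch_size"])
```
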
whisper-medium-finetuned/src/__init__.py (new file)
@@ -0,0 +1,2 @@
# Whisper Persian Fine-tuning Pipeline
__version__ = "0.1.0"

whisper-medium-finetuned/src/dataset_processor.py (new file)
@@ -0,0 +1,258 @@
"""
Persian Dataset Processor

This module handles loading and preprocessing of the confirmed Persian dataset.
"""

import pandas as pd
import numpy as np
import torch
import librosa
from datasets import Dataset
from typing import List, Dict, Any, Optional, Tuple
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


class PersianDatasetProcessor:
    """Handles loading and preprocessing of Persian audio-text dataset."""

    def __init__(self, parquet_path: str, sample_rate: int = 16000):
        self.parquet_path = parquet_path
        self.sample_rate = sample_rate
        self.dataset = None
        self.validated_data = []

    def load_confirmed_dataset(self, parquet_path: Optional[str] = None) -> Dataset:
        """
        Load audio arrays and transcriptions from confirmed.parquet file.

        Args:
            parquet_path: Path to the parquet file (optional, uses instance path if None)

        Returns:
            Dataset: HuggingFace Dataset object
        """
        path = parquet_path or self.parquet_path

        try:
            logger.info(f"Loading dataset from {path}")
            df = pd.read_parquet(path)

            if df.empty:
                raise ValueError("Dataset is empty")

            logger.info(f"Loaded {len(df)} samples from dataset")

            # Validate required columns
            required_columns = ['audio', 'transcription']
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                raise ValueError(f"Missing required columns: {missing_columns}")

            # Convert to HuggingFace Dataset
            dataset_dict = {
                'audio': df['audio'].tolist(),
                'transcription': df['transcription'].tolist()
            }

            self.dataset = Dataset.from_dict(dataset_dict)
            logger.info("Dataset converted to HuggingFace format")

            return self.dataset

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def validate_audio_data(self, dataset: Optional[Dataset] = None) -> bool:
        """
        Validate audio data integrity and transcription quality.

        Args:
            dataset: Dataset to validate (optional, uses instance dataset if None)

        Returns:
            bool: True if validation passes
        """
        data = dataset or self.dataset
        if data is None:
            raise ValueError("No dataset to validate")

        try:
            logger.info("Validating audio data...")
            valid_samples = []
            invalid_count = 0

            for i, sample in enumerate(data):
                audio_array = sample['audio']
                transcription = sample['transcription']

                # Validate audio
                if len(audio_array) == 0:
                    logger.warning(f"Sample {i}: Empty audio array")
                    invalid_count += 1
                    continue

                if np.isnan(audio_array).any() or np.isinf(audio_array).any():
                    logger.warning(f"Sample {i}: Invalid audio values (NaN/Inf)")
                    invalid_count += 1
                    continue

                # Calculate duration
                duration = len(audio_array) / self.sample_rate
                if duration < 0.5 or duration > 30.0:  # Filter very short/long audio
                    logger.warning(f"Sample {i}: Invalid duration {duration:.2f}s")
                    invalid_count += 1
                    continue

                # Validate transcription
                if not transcription or not transcription.strip():
                    logger.warning(f"Sample {i}: Empty transcription")
                    invalid_count += 1
                    continue

                if len(transcription.strip()) < 2:
                    logger.warning(f"Sample {i}: Transcription too short")
                    invalid_count += 1
                    continue

                # Add valid sample
                valid_samples.append({
                    'audio': audio_array,
                    'transcription': transcription.strip(),
                    'duration': duration
                })

            self.validated_data = valid_samples
            logger.info(f"Validation complete: {len(valid_samples)} valid, {invalid_count} invalid samples")

            return len(valid_samples) > 0

        except Exception as e:
            logger.error(f"Validation failed: {e}")
            return False

    def prepare_audio_features(self, audio_array: np.ndarray) -> torch.Tensor:
        """
        Convert audio array to mel-spectrogram features for Whisper.

        Args:
            audio_array: Raw audio array

        Returns:
            torch.Tensor: Mel-spectrogram features
        """
        try:
            # Ensure audio is the right length and format
            if len(audio_array.shape) > 1:
                audio_array = audio_array.flatten()

            # Resample if necessary
            if len(audio_array) == 0:
                raise ValueError("Empty audio array")

            # Pad or truncate to 30 seconds max
            max_length = 30 * self.sample_rate
            if len(audio_array) > max_length:
                audio_array = audio_array[:max_length]

            # Convert to mel-spectrogram using librosa (Whisper-compatible)
            mel_spec = librosa.feature.melspectrogram(
                y=audio_array,
                sr=self.sample_rate,
                n_mels=80,  # Whisper uses 80 mel bins
                hop_length=160,  # Whisper hop length
                n_fft=400  # Whisper n_fft
            )

            # Convert to log scale
            log_mel = librosa.power_to_db(mel_spec)

            # Convert to tensor
            features = torch.from_numpy(log_mel).float()

            return features

        except Exception as e:
            logger.error(f"Failed to prepare audio features: {e}")
            raise

    def create_training_dataset(self, train_split: float = 0.9) -> Tuple[Dataset, Dataset]:
        """
        Create training and validation datasets from validated data.

        Args:
            train_split: Fraction of data to use for training

        Returns:
            Tuple of (train_dataset, eval_dataset)
        """
        if not self.validated_data:
            raise ValueError("No validated data available. Run validate_audio_data first.")

        try:
            logger.info("Creating training datasets...")

            # Shuffle data
            import random
            random.shuffle(self.validated_data)

            # Split data
            split_idx = int(len(self.validated_data) * train_split)
            train_data = self.validated_data[:split_idx]
            eval_data = self.validated_data[split_idx:]

            logger.info(f"Train samples: {len(train_data)}, Eval samples: {len(eval_data)}")

            # Create datasets
            train_dict = {
                'audio': [sample['audio'] for sample in train_data],
                'transcription': [sample['transcription'] for sample in train_data],
                'duration': [sample['duration'] for sample in train_data]
            }

            eval_dict = {
                'audio': [sample['audio'] for sample in eval_data],
                'transcription': [sample['transcription'] for sample in eval_data],
                'duration': [sample['duration'] for sample in eval_data]
            }

            train_dataset = Dataset.from_dict(train_dict)
            eval_dataset = Dataset.from_dict(eval_dict)

            logger.info("Training datasets created successfully")

            return train_dataset, eval_dataset

        except Exception as e:
            logger.error(f"Failed to create training datasets: {e}")
            raise

    def get_dataset_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the validated dataset.

        Returns:
            Dict containing dataset statistics
        """
        if not self.validated_data:
            return {}

        durations = [sample['duration'] for sample in self.validated_data]
        transcription_lengths = [len(sample['transcription']) for sample in self.validated_data]

        stats = {
            'total_samples': len(self.validated_data),
            'total_duration_hours': sum(durations) / 3600,
            'avg_duration_seconds': np.mean(durations),
            'min_duration_seconds': min(durations),
            'max_duration_seconds': max(durations),
            'avg_transcription_length': np.mean(transcription_lengths),
            'min_transcription_length': min(transcription_lengths),
            'max_transcription_length': max(transcription_lengths)
        }

        return stats

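A short usage sketch for `PersianDatasetProcessor`, using only the methods defined above (whether `confirmed_dataset/confirmed.parquet` exists, and whether its `audio`/`transcription` columns match what this class expects, depends on the upstream confirmation step):

```python
# Hypothetical driver for the processor defined above.
processor = PersianDatasetProcessor("confirmed_dataset/confirmed.parquet", sample_rate=16000)
processor.load_confirmed_dataset()

if processor.validate_audio_data():
    train_ds, eval_ds = processor.create_training_dataset(train_split=0.9)
    print(processor.get_dataset_statistics())
```
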
whisper-medium-finetuned/src/faster_whisper_converter.py (new file)
@@ -0,0 +1 @@
# Faster Whisper conversion functionality

whisper-medium-finetuned/src/finetuning_preparator.py (new file)
@@ -0,0 +1 @@
# Fine-tuning preparation functionality

whisper-medium-finetuned/src/model_loader.py (new file)
@@ -0,0 +1,152 @@
"""
Whisper Model Loader

This module handles loading and initializing the Whisper Medium model from Hugging Face.
"""

import torch
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    WhisperFeatureExtractor
)
from typing import Tuple, Optional
import logging

logger = logging.getLogger(__name__)


class WhisperModelLoader:
    """Handles loading and verification of Whisper models from Hugging Face."""

    def __init__(self, model_name: str = "openai/whisper-medium"):
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

    def load_model(self) -> WhisperForConditionalGeneration:
        """
        Load the Whisper model from Hugging Face.

        Returns:
            WhisperForConditionalGeneration: The loaded model
        """
        try:
            logger.info(f"Loading Whisper model: {self.model_name}")
            model = WhisperForConditionalGeneration.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
            )
            model.to(self.device)
            logger.info("Model loaded successfully")
            return model
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def load_processor(self) -> WhisperProcessor:
        """
        Load the Whisper processor from Hugging Face.

        Returns:
            WhisperProcessor: The loaded processor
        """
        try:
            logger.info(f"Loading Whisper processor: {self.model_name}")
            processor = WhisperProcessor.from_pretrained(self.model_name)
            logger.info("Processor loaded successfully")
            return processor
        except Exception as e:
            logger.error(f"Failed to load processor: {e}")
            raise

    def load_tokenizer(self) -> WhisperTokenizer:
        """
        Load the Whisper tokenizer separately.

        Returns:
            WhisperTokenizer: The loaded tokenizer
        """
        try:
            logger.info(f"Loading Whisper tokenizer: {self.model_name}")
            tokenizer = WhisperTokenizer.from_pretrained(self.model_name)
            logger.info("Tokenizer loaded successfully")
            return tokenizer
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def load_feature_extractor(self) -> WhisperFeatureExtractor:
        """
        Load the Whisper feature extractor separately.

        Returns:
            WhisperFeatureExtractor: The loaded feature extractor
        """
        try:
            logger.info(f"Loading Whisper feature extractor: {self.model_name}")
            feature_extractor = WhisperFeatureExtractor.from_pretrained(self.model_name)
            logger.info("Feature extractor loaded successfully")
            return feature_extractor
        except Exception as e:
            logger.error(f"Failed to load feature extractor: {e}")
            raise

    def verify_model_compatibility(self, model: WhisperForConditionalGeneration) -> bool:
        """
        Verify that the loaded model is compatible for fine-tuning.

        Args:
            model: The loaded Whisper model

        Returns:
            bool: True if compatible, False otherwise
        """
        try:
            # Check if model has the expected architecture
            if not hasattr(model, 'model'):
                logger.error("Model does not have expected 'model' attribute")
                return False

            if not hasattr(model.model, 'encoder') or not hasattr(model.model, 'decoder'):
                logger.error("Model does not have encoder/decoder architecture")
                return False

            # Check if model supports gradient computation
            if not any(p.requires_grad for p in model.parameters()):
                logger.warning("No parameters require gradients - enabling gradients")
                for param in model.parameters():
                    param.requires_grad = True

            # Test forward pass with dummy input
            dummy_input = torch.randn(1, 80, 3000).to(self.device)
            dummy_labels = torch.randint(0, 1000, (1, 10)).to(self.device)

            with torch.no_grad():
                outputs = model(input_features=dummy_input, labels=dummy_labels)
                if not hasattr(outputs, 'loss'):
                    logger.error("Model output does not contain loss")
                    return False

            logger.info("Model compatibility verification passed")
            return True

        except Exception as e:
            logger.error(f"Model compatibility verification failed: {e}")
            return False

    def load_all_components(self) -> Tuple[WhisperForConditionalGeneration, WhisperProcessor]:
        """
        Load all necessary components for fine-tuning.

        Returns:
            Tuple containing (model, processor)
        """
        model = self.load_model()
        processor = self.load_processor()

        if not self.verify_model_compatibility(model):
            raise ValueError("Model compatibility verification failed")

        return model, processor

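A short usage sketch for the loader, based only on the methods defined above:

```python
# Hypothetical usage of WhisperModelLoader: load and verify model + processor.
loader = WhisperModelLoader("openai/whisper-medium")
model, processor = loader.load_all_components()
print(type(model).__name__, "loaded on", loader.device)
```
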
whisper-medium-finetuned/src/persian_tokenizer.py (new file)
@@ -0,0 +1 @@
# Persian tokenization functionality placeholder

whisper-medium-finetuned/tests/__init__.py (new file)
@@ -0,0 +1 @@
# Test package

whisper-medium-finetuned/tests/test_dataset_processor.py (new file)
@@ -0,0 +1,7 @@
"""
Unit tests for PersianDatasetProcessor
"""

import unittest
from unittest.mock import patch, MagicMock
import numpy as np

whisper-medium-finetuned/tests/test_model_loader.py (new file)
@@ -0,0 +1,76 @@
"""
Unit tests for WhisperModelLoader
"""

import unittest
from unittest.mock import patch, MagicMock
import torch
import sys
import os

# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from model_loader import WhisperModelLoader


class TestWhisperModelLoader(unittest.TestCase):

    def setUp(self):
        self.loader = WhisperModelLoader("openai/whisper-medium")

    def test_init(self):
        """Test WhisperModelLoader initialization"""
        self.assertEqual(self.loader.model_name, "openai/whisper-medium")
        self.assertIsInstance(self.loader.device, torch.device)

    @patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
    def test_load_model_success(self, mock_from_pretrained):
        """Test successful model loading"""
        mock_model = MagicMock()
        mock_from_pretrained.return_value = mock_model

        result = self.loader.load_model()

        mock_from_pretrained.assert_called_once()
        mock_model.to.assert_called_once_with(self.loader.device)
        self.assertEqual(result, mock_model)

    @patch('model_loader.WhisperProcessor.from_pretrained')
    def test_load_processor_success(self, mock_from_pretrained):
        """Test successful processor loading"""
        mock_processor = MagicMock()
        mock_from_pretrained.return_value = mock_processor

        result = self.loader.load_processor()

        mock_from_pretrained.assert_called_once_with("openai/whisper-medium")
        self.assertEqual(result, mock_processor)

    def test_verify_model_compatibility_valid_model(self):
        """Test model compatibility verification with valid model"""
        mock_model = MagicMock()
        mock_model.model.encoder = MagicMock()
        mock_model.model.decoder = MagicMock()
        mock_model.parameters.return_value = [MagicMock(requires_grad=True)]

        # Mock forward pass
        mock_output = MagicMock()
        mock_output.loss = torch.tensor(1.0)
        mock_model.return_value = mock_output

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertTrue(result)

    def test_verify_model_compatibility_invalid_model(self):
        """Test model compatibility verification with invalid model"""
        mock_model = MagicMock()
        # Remove required attributes
        del mock_model.model

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()
