Sure! Pl
This commit is contained in:
31
whisper-medium-finetuned/config/training_config.yaml
Normal file
31
whisper-medium-finetuned/config/training_config.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
model:
|
||||
name: "openai/whisper-medium"
|
||||
language: "persian"
|
||||
task: "transcribe"
|
||||
|
||||
training:
|
||||
batch_size: 16
|
||||
learning_rate: 1e-5
|
||||
num_epochs: 10
|
||||
warmup_steps: 500
|
||||
gradient_accumulation_steps: 2
|
||||
save_steps: 1000
|
||||
eval_steps: 500
|
||||
logging_steps: 100
|
||||
|
||||
data:
|
||||
dataset_path: "confirmed_dataset/confirmed.parquet"
|
||||
max_audio_length: 30.0
|
||||
min_audio_length: 1.0
|
||||
train_split: 0.9
|
||||
eval_split: 0.1
|
||||
|
||||
persian:
|
||||
use_hazm: true
|
||||
normalize_text: true
|
||||
remove_diacritics: true
|
||||
|
||||
output:
|
||||
model_dir: "whisper-medium-finetuned/models"
|
||||
logs_dir: "whisper-medium-finetuned/logs"
|
||||
faster_whisper_dir: "whisper-medium-finetuned/faster_whisper"
|
||||
2
whisper-medium-finetuned/src/__init__.py
Normal file
2
whisper-medium-finetuned/src/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Whisper Persian Fine-tuning Pipeline
|
||||
__version__ = "0.1.0"
|
||||
258
whisper-medium-finetuned/src/dataset_processor.py
Normal file
258
whisper-medium-finetuned/src/dataset_processor.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Persian Dataset Processor
|
||||
|
||||
This module handles loading and preprocessing of the confirmed Persian dataset.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
import librosa
|
||||
from datasets import Dataset
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersianDatasetProcessor:
|
||||
"""Handles loading and preprocessing of Persian audio-text dataset."""
|
||||
|
||||
def __init__(self, parquet_path: str, sample_rate: int = 16000):
|
||||
self.parquet_path = parquet_path
|
||||
self.sample_rate = sample_rate
|
||||
self.dataset = None
|
||||
self.validated_data = []
|
||||
|
||||
def load_confirmed_dataset(self, parquet_path: Optional[str] = None) -> Dataset:
|
||||
"""
|
||||
Load audio arrays and transcriptions from confirmed.parquet file.
|
||||
|
||||
Args:
|
||||
parquet_path: Path to the parquet file (optional, uses instance path if None)
|
||||
|
||||
Returns:
|
||||
Dataset: HuggingFace Dataset object
|
||||
"""
|
||||
path = parquet_path or self.parquet_path
|
||||
|
||||
try:
|
||||
logger.info(f"Loading dataset from {path}")
|
||||
df = pd.read_parquet(path)
|
||||
|
||||
if df.empty:
|
||||
raise ValueError("Dataset is empty")
|
||||
|
||||
logger.info(f"Loaded {len(df)} samples from dataset")
|
||||
|
||||
# Validate required columns
|
||||
required_columns = ['audio', 'transcription']
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
if missing_columns:
|
||||
raise ValueError(f"Missing required columns: {missing_columns}")
|
||||
|
||||
# Convert to HuggingFace Dataset
|
||||
dataset_dict = {
|
||||
'audio': df['audio'].tolist(),
|
||||
'transcription': df['transcription'].tolist()
|
||||
}
|
||||
|
||||
self.dataset = Dataset.from_dict(dataset_dict)
|
||||
logger.info("Dataset converted to HuggingFace format")
|
||||
|
||||
return self.dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load dataset: {e}")
|
||||
raise
|
||||
|
||||
def validate_audio_data(self, dataset: Optional[Dataset] = None) -> bool:
|
||||
"""
|
||||
Validate audio data integrity and transcription quality.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to validate (optional, uses instance dataset if None)
|
||||
|
||||
Returns:
|
||||
bool: True if validation passes
|
||||
"""
|
||||
data = dataset or self.dataset
|
||||
if data is None:
|
||||
raise ValueError("No dataset to validate")
|
||||
|
||||
try:
|
||||
logger.info("Validating audio data...")
|
||||
valid_samples = []
|
||||
invalid_count = 0
|
||||
|
||||
for i, sample in enumerate(data):
|
||||
audio_array = sample['audio']
|
||||
transcription = sample['transcription']
|
||||
|
||||
# Validate audio
|
||||
if len(audio_array) == 0:
|
||||
logger.warning(f"Sample {i}: Empty audio array")
|
||||
invalid_count += 1
|
||||
continue
|
||||
|
||||
if np.isnan(audio_array).any() or np.isinf(audio_array).any():
|
||||
logger.warning(f"Sample {i}: Invalid audio values (NaN/Inf)")
|
||||
invalid_count += 1
|
||||
continue
|
||||
|
||||
# Calculate duration
|
||||
duration = len(audio_array) / self.sample_rate
|
||||
if duration < 0.5 or duration > 30.0: # Filter very short/long audio
|
||||
logger.warning(f"Sample {i}: Invalid duration {duration:.2f}s")
|
||||
invalid_count += 1
|
||||
continue
|
||||
|
||||
# Validate transcription
|
||||
if not transcription or not transcription.strip():
|
||||
logger.warning(f"Sample {i}: Empty transcription")
|
||||
invalid_count += 1
|
||||
continue
|
||||
|
||||
if len(transcription.strip()) < 2:
|
||||
logger.warning(f"Sample {i}: Transcription too short")
|
||||
invalid_count += 1
|
||||
continue
|
||||
|
||||
# Add valid sample
|
||||
valid_samples.append({
|
||||
'audio': audio_array,
|
||||
'transcription': transcription.strip(),
|
||||
'duration': duration
|
||||
})
|
||||
|
||||
self.validated_data = valid_samples
|
||||
logger.info(f"Validation complete: {len(valid_samples)} valid, {invalid_count} invalid samples")
|
||||
|
||||
return len(valid_samples) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Validation failed: {e}")
|
||||
return False
|
||||
|
||||
def prepare_audio_features(self, audio_array: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Convert audio array to mel-spectrogram features for Whisper.
|
||||
|
||||
Args:
|
||||
audio_array: Raw audio array
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel-spectrogram features
|
||||
"""
|
||||
try:
|
||||
# Ensure audio is the right length and format
|
||||
if len(audio_array.shape) > 1:
|
||||
audio_array = audio_array.flatten()
|
||||
|
||||
# Resample if necessary
|
||||
if len(audio_array) == 0:
|
||||
raise ValueError("Empty audio array")
|
||||
|
||||
# Pad or truncate to 30 seconds max
|
||||
max_length = 30 * self.sample_rate
|
||||
if len(audio_array) > max_length:
|
||||
audio_array = audio_array[:max_length]
|
||||
|
||||
# Convert to mel-spectrogram using librosa (Whisper-compatible)
|
||||
mel_spec = librosa.feature.melspectrogram(
|
||||
y=audio_array,
|
||||
sr=self.sample_rate,
|
||||
n_mels=80, # Whisper uses 80 mel bins
|
||||
hop_length=160, # Whisper hop length
|
||||
n_fft=400 # Whisper n_fft
|
||||
)
|
||||
|
||||
# Convert to log scale
|
||||
log_mel = librosa.power_to_db(mel_spec)
|
||||
|
||||
# Convert to tensor
|
||||
features = torch.from_numpy(log_mel).float()
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare audio features: {e}")
|
||||
raise
|
||||
|
||||
def create_training_dataset(self, train_split: float = 0.9) -> Tuple[Dataset, Dataset]:
|
||||
"""
|
||||
Create training and validation datasets from validated data.
|
||||
|
||||
Args:
|
||||
train_split: Fraction of data to use for training
|
||||
|
||||
Returns:
|
||||
Tuple of (train_dataset, eval_dataset)
|
||||
"""
|
||||
if not self.validated_data:
|
||||
raise ValueError("No validated data available. Run validate_audio_data first.")
|
||||
|
||||
try:
|
||||
logger.info("Creating training datasets...")
|
||||
|
||||
# Shuffle data
|
||||
import random
|
||||
random.shuffle(self.validated_data)
|
||||
|
||||
# Split data
|
||||
split_idx = int(len(self.validated_data) * train_split)
|
||||
train_data = self.validated_data[:split_idx]
|
||||
eval_data = self.validated_data[split_idx:]
|
||||
|
||||
logger.info(f"Train samples: {len(train_data)}, Eval samples: {len(eval_data)}")
|
||||
|
||||
# Create datasets
|
||||
train_dict = {
|
||||
'audio': [sample['audio'] for sample in train_data],
|
||||
'transcription': [sample['transcription'] for sample in train_data],
|
||||
'duration': [sample['duration'] for sample in train_data]
|
||||
}
|
||||
|
||||
eval_dict = {
|
||||
'audio': [sample['audio'] for sample in eval_data],
|
||||
'transcription': [sample['transcription'] for sample in eval_data],
|
||||
'duration': [sample['duration'] for sample in eval_data]
|
||||
}
|
||||
|
||||
train_dataset = Dataset.from_dict(train_dict)
|
||||
eval_dataset = Dataset.from_dict(eval_dict)
|
||||
|
||||
logger.info("Training datasets created successfully")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create training datasets: {e}")
|
||||
raise
|
||||
|
||||
def get_dataset_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics about the validated dataset.
|
||||
|
||||
Returns:
|
||||
Dict containing dataset statistics
|
||||
"""
|
||||
if not self.validated_data:
|
||||
return {}
|
||||
|
||||
durations = [sample['duration'] for sample in self.validated_data]
|
||||
transcription_lengths = [len(sample['transcription']) for sample in self.validated_data]
|
||||
|
||||
stats = {
|
||||
'total_samples': len(self.validated_data),
|
||||
'total_duration_hours': sum(durations) / 3600,
|
||||
'avg_duration_seconds': np.mean(durations),
|
||||
'min_duration_seconds': min(durations),
|
||||
'max_duration_seconds': max(durations),
|
||||
'avg_transcription_length': np.mean(transcription_lengths),
|
||||
'min_transcription_length': min(transcription_lengths),
|
||||
'max_transcription_length': max(transcription_lengths)
|
||||
}
|
||||
|
||||
return stats
|
||||
1
whisper-medium-finetuned/src/faster_whisper_converter.py
Normal file
1
whisper-medium-finetuned/src/faster_whisper_converter.py
Normal file
@@ -0,0 +1 @@
|
||||
# Faster Whisper conversion functionality
|
||||
1
whisper-medium-finetuned/src/finetuning_preparator.py
Normal file
1
whisper-medium-finetuned/src/finetuning_preparator.py
Normal file
@@ -0,0 +1 @@
|
||||
# Fine-tuning preparation functionality
|
||||
152
whisper-medium-finetuned/src/model_loader.py
Normal file
152
whisper-medium-finetuned/src/model_loader.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Whisper Model Loader
|
||||
|
||||
This module handles loading and initializing the Whisper Medium model from Hugging Face.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperProcessor,
|
||||
WhisperTokenizer,
|
||||
WhisperFeatureExtractor
|
||||
)
|
||||
from typing import Tuple, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhisperModelLoader:
|
||||
"""Handles loading and verification of Whisper models from Hugging Face."""
|
||||
|
||||
def __init__(self, model_name: str = "openai/whisper-medium"):
|
||||
self.model_name = model_name
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
def load_model(self) -> WhisperForConditionalGeneration:
|
||||
"""
|
||||
Load the Whisper model from Hugging Face.
|
||||
|
||||
Returns:
|
||||
WhisperForConditionalGeneration: The loaded model
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading Whisper model: {self.model_name}")
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
|
||||
)
|
||||
model.to(self.device)
|
||||
logger.info("Model loaded successfully")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
|
||||
def load_processor(self) -> WhisperProcessor:
|
||||
"""
|
||||
Load the Whisper processor from Hugging Face.
|
||||
|
||||
Returns:
|
||||
WhisperProcessor: The loaded processor
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading Whisper processor: {self.model_name}")
|
||||
processor = WhisperProcessor.from_pretrained(self.model_name)
|
||||
logger.info("Processor loaded successfully")
|
||||
return processor
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load processor: {e}")
|
||||
raise
|
||||
|
||||
def load_tokenizer(self) -> WhisperTokenizer:
|
||||
"""
|
||||
Load the Whisper tokenizer separately.
|
||||
|
||||
Returns:
|
||||
WhisperTokenizer: The loaded tokenizer
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading Whisper tokenizer: {self.model_name}")
|
||||
tokenizer = WhisperTokenizer.from_pretrained(self.model_name)
|
||||
logger.info("Tokenizer loaded successfully")
|
||||
return tokenizer
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer: {e}")
|
||||
raise
|
||||
|
||||
def load_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
"""
|
||||
Load the Whisper feature extractor separately.
|
||||
|
||||
Returns:
|
||||
WhisperFeatureExtractor: The loaded feature extractor
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Loading Whisper feature extractor: {self.model_name}")
|
||||
feature_extractor = WhisperFeatureExtractor.from_pretrained(self.model_name)
|
||||
logger.info("Feature extractor loaded successfully")
|
||||
return feature_extractor
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load feature extractor: {e}")
|
||||
raise
|
||||
|
||||
def verify_model_compatibility(self, model: WhisperForConditionalGeneration) -> bool:
|
||||
"""
|
||||
Verify that the loaded model is compatible for fine-tuning.
|
||||
|
||||
Args:
|
||||
model: The loaded Whisper model
|
||||
|
||||
Returns:
|
||||
bool: True if compatible, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check if model has the expected architecture
|
||||
if not hasattr(model, 'model'):
|
||||
logger.error("Model does not have expected 'model' attribute")
|
||||
return False
|
||||
|
||||
if not hasattr(model.model, 'encoder') or not hasattr(model.model, 'decoder'):
|
||||
logger.error("Model does not have encoder/decoder architecture")
|
||||
return False
|
||||
|
||||
# Check if model supports gradient computation
|
||||
if not any(p.requires_grad for p in model.parameters()):
|
||||
logger.warning("No parameters require gradients - enabling gradients")
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
# Test forward pass with dummy input
|
||||
dummy_input = torch.randn(1, 80, 3000).to(self.device)
|
||||
dummy_labels = torch.randint(0, 1000, (1, 10)).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_features=dummy_input, labels=dummy_labels)
|
||||
if not hasattr(outputs, 'loss'):
|
||||
logger.error("Model output does not contain loss")
|
||||
return False
|
||||
|
||||
logger.info("Model compatibility verification passed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Model compatibility verification failed: {e}")
|
||||
return False
|
||||
|
||||
def load_all_components(self) -> Tuple[WhisperForConditionalGeneration, WhisperProcessor]:
|
||||
"""
|
||||
Load all necessary components for fine-tuning.
|
||||
|
||||
Returns:
|
||||
Tuple containing (model, processor)
|
||||
"""
|
||||
model = self.load_model()
|
||||
processor = self.load_processor()
|
||||
|
||||
if not self.verify_model_compatibility(model):
|
||||
raise ValueError("Model compatibility verification failed")
|
||||
|
||||
return model, processor
|
||||
1
whisper-medium-finetuned/src/persian_tokenizer.py
Normal file
1
whisper-medium-finetuned/src/persian_tokenizer.py
Normal file
@@ -0,0 +1 @@
|
||||
# Persian tokenization functionality placeholder
|
||||
1
whisper-medium-finetuned/tests/__init__.py
Normal file
1
whisper-medium-finetuned/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Test package
|
||||
7
whisper-medium-finetuned/tests/test_dataset_processor.py
Normal file
7
whisper-medium-finetuned/tests/test_dataset_processor.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Unit tests for PersianDatasetProcessor
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import numpy as n
|
||||
76
whisper-medium-finetuned/tests/test_model_loader.py
Normal file
76
whisper-medium-finetuned/tests/test_model_loader.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Unit tests for WhisperModelLoader
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add src to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from model_loader import WhisperModelLoader
|
||||
|
||||
|
||||
class TestWhisperModelLoader(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loader = WhisperModelLoader("openai/whisper-medium")
|
||||
|
||||
def test_init(self):
|
||||
"""Test WhisperModelLoader initialization"""
|
||||
self.assertEqual(self.loader.model_name, "openai/whisper-medium")
|
||||
self.assertIsInstance(self.loader.device, torch.device)
|
||||
|
||||
@patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
|
||||
def test_load_model_success(self, mock_from_pretrained):
|
||||
"""Test successful model loading"""
|
||||
mock_model = MagicMock()
|
||||
mock_from_pretrained.return_value = mock_model
|
||||
|
||||
result = self.loader.load_model()
|
||||
|
||||
mock_from_pretrained.assert_called_once()
|
||||
mock_model.to.assert_called_once_with(self.loader.device)
|
||||
self.assertEqual(result, mock_model)
|
||||
|
||||
@patch('model_loader.WhisperProcessor.from_pretrained')
|
||||
def test_load_processor_success(self, mock_from_pretrained):
|
||||
"""Test successful processor loading"""
|
||||
mock_processor = MagicMock()
|
||||
mock_from_pretrained.return_value = mock_processor
|
||||
|
||||
result = self.loader.load_processor()
|
||||
|
||||
mock_from_pretrained.assert_called_once_with("openai/whisper-medium")
|
||||
self.assertEqual(result, mock_processor)
|
||||
|
||||
def test_verify_model_compatibility_valid_model(self):
|
||||
"""Test model compatibility verification with valid model"""
|
||||
mock_model = MagicMock()
|
||||
mock_model.model.encoder = MagicMock()
|
||||
mock_model.model.decoder = MagicMock()
|
||||
mock_model.parameters.return_value = [MagicMock(requires_grad=True)]
|
||||
|
||||
# Mock forward pass
|
||||
mock_output = MagicMock()
|
||||
mock_output.loss = torch.tensor(1.0)
|
||||
mock_model.return_value = mock_output
|
||||
|
||||
result = self.loader.verify_model_compatibility(mock_model)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_verify_model_compatibility_invalid_model(self):
|
||||
"""Test model compatibility verification with invalid model"""
|
||||
mock_model = MagicMock()
|
||||
# Remove required attributes
|
||||
del mock_model.model
|
||||
|
||||
result = self.loader.verify_model_compatibility(mock_model)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user