This commit is contained in:
Alireza
2025-07-31 17:35:08 +03:30
commit 640363fef2
27 changed files with 4201 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Test package

View File

@@ -0,0 +1,7 @@
"""
Unit tests for PersianDatasetProcessor
"""
import unittest
from unittest.mock import patch, MagicMock
import numpy as n

View File

@@ -0,0 +1,76 @@
"""
Unit tests for WhisperModelLoader
"""
import unittest
from unittest.mock import patch, MagicMock
import torch
import sys
import os
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from model_loader import WhisperModelLoader
class TestWhisperModelLoader(unittest.TestCase):
def setUp(self):
self.loader = WhisperModelLoader("openai/whisper-medium")
def test_init(self):
"""Test WhisperModelLoader initialization"""
self.assertEqual(self.loader.model_name, "openai/whisper-medium")
self.assertIsInstance(self.loader.device, torch.device)
@patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
def test_load_model_success(self, mock_from_pretrained):
"""Test successful model loading"""
mock_model = MagicMock()
mock_from_pretrained.return_value = mock_model
result = self.loader.load_model()
mock_from_pretrained.assert_called_once()
mock_model.to.assert_called_once_with(self.loader.device)
self.assertEqual(result, mock_model)
@patch('model_loader.WhisperProcessor.from_pretrained')
def test_load_processor_success(self, mock_from_pretrained):
"""Test successful processor loading"""
mock_processor = MagicMock()
mock_from_pretrained.return_value = mock_processor
result = self.loader.load_processor()
mock_from_pretrained.assert_called_once_with("openai/whisper-medium")
self.assertEqual(result, mock_processor)
def test_verify_model_compatibility_valid_model(self):
"""Test model compatibility verification with valid model"""
mock_model = MagicMock()
mock_model.model.encoder = MagicMock()
mock_model.model.decoder = MagicMock()
mock_model.parameters.return_value = [MagicMock(requires_grad=True)]
# Mock forward pass
mock_output = MagicMock()
mock_output.loss = torch.tensor(1.0)
mock_model.return_value = mock_output
result = self.loader.verify_model_compatibility(mock_model)
self.assertTrue(result)
def test_verify_model_compatibility_invalid_model(self):
"""Test model compatibility verification with invalid model"""
mock_model = MagicMock()
# Remove required attributes
del mock_model.model
result = self.loader.verify_model_compatibility(mock_model)
self.assertFalse(result)
if __name__ == '__main__':
unittest.main()