whisper-medium-finetuned/tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Test package
whisper-medium-finetuned/tests/test_dataset_processor.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""
Unit tests for PersianDatasetProcessor
"""

import unittest
from unittest.mock import patch, MagicMock
import numpy as np
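Note: as committed, test_dataset_processor.py ends after its imports; no test cases exist yet. The numpy import hints that the eventual tests will feed dummy audio arrays to the processor. The snippet below is a purely hypothetical illustration of that pattern; because PersianDatasetProcessor's API is not visible in this commit, a MagicMock stands in for it and the process_audio name is an assumption.

# Hypothetical illustration only -- PersianDatasetProcessor's real API is not shown
# in this commit, so the process_audio() name and return shape are assumptions.
import numpy as np
from unittest.mock import MagicMock

# One second of silence at Whisper's expected 16 kHz sampling rate stands in for real speech.
dummy_audio = np.zeros(16000, dtype=np.float32)

# A MagicMock plays the role of the processor so this snippet runs on its own.
processor = MagicMock()
processor.process_audio.return_value = {"input_features": dummy_audio}

features = processor.process_audio(dummy_audio, sampling_rate=16000)
processor.process_audio.assert_called_once()
assert "input_features" in features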
whisper-medium-finetuned/tests/test_model_loader.py (new file, 76 lines)
@@ -0,0 +1,76 @@
"""
Unit tests for WhisperModelLoader
"""

import unittest
from unittest.mock import patch, MagicMock
import torch
import sys
import os

# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from model_loader import WhisperModelLoader


class TestWhisperModelLoader(unittest.TestCase):

    def setUp(self):
        self.loader = WhisperModelLoader("openai/whisper-medium")

    def test_init(self):
        """Test WhisperModelLoader initialization"""
        self.assertEqual(self.loader.model_name, "openai/whisper-medium")
        self.assertIsInstance(self.loader.device, torch.device)

    @patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
    def test_load_model_success(self, mock_from_pretrained):
        """Test successful model loading"""
        mock_model = MagicMock()
        mock_from_pretrained.return_value = mock_model

        result = self.loader.load_model()

        mock_from_pretrained.assert_called_once()
        mock_model.to.assert_called_once_with(self.loader.device)
        self.assertEqual(result, mock_model)

    @patch('model_loader.WhisperProcessor.from_pretrained')
    def test_load_processor_success(self, mock_from_pretrained):
        """Test successful processor loading"""
        mock_processor = MagicMock()
        mock_from_pretrained.return_value = mock_processor

        result = self.loader.load_processor()

        mock_from_pretrained.assert_called_once_with("openai/whisper-medium")
        self.assertEqual(result, mock_processor)

    def test_verify_model_compatibility_valid_model(self):
        """Test model compatibility verification with valid model"""
        mock_model = MagicMock()
        mock_model.model.encoder = MagicMock()
        mock_model.model.decoder = MagicMock()
        mock_model.parameters.return_value = [MagicMock(requires_grad=True)]

        # Mock forward pass
        mock_output = MagicMock()
        mock_output.loss = torch.tensor(1.0)
        mock_model.return_value = mock_output

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertTrue(result)

    def test_verify_model_compatibility_invalid_model(self):
        """Test model compatibility verification with invalid model"""
        mock_model = MagicMock()
        # Remove required attributes
        del mock_model.model

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()
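Note: src/model_loader.py itself is not included in this commit. As a reading aid only, the sketch below shows the interface the tests above exercise (a model_name string, a torch.device in self.device, load_model(), load_processor(), and verify_model_compatibility()); the method bodies are assumptions, and judging by the mocks in the valid-model test, the real verify_model_compatibility likely also inspects trainable parameters and runs a dummy forward pass.

# Hypothetical sketch of src/model_loader.py -- not part of this commit.
# Attribute and method names come from the tests; the bodies are assumptions.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor


class WhisperModelLoader:

    def __init__(self, model_name: str):
        self.model_name = model_name
        # test_init checks that this is a torch.device instance.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        # test_load_model_success expects from_pretrained() followed by .to(self.device).
        model = WhisperForConditionalGeneration.from_pretrained(self.model_name)
        model.to(self.device)
        return model

    def load_processor(self):
        # test_load_processor_success expects from_pretrained(self.model_name).
        return WhisperProcessor.from_pretrained(self.model_name)

    def verify_model_compatibility(self, model) -> bool:
        # Simplified: the tests only hinge on model.model.encoder / .decoder being present;
        # the additional parameter and forward-pass checks implied by the mocks are omitted.
        try:
            _ = model.model.encoder
            _ = model.model.decoder
            return True
        except AttributeError:
            return False

Because each test module prepends ../src to sys.path relative to its own location, the files can be run directly (for example, python whisper-medium-finetuned/tests/test_model_loader.py) or through standard unittest discovery.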