"""Unit tests for WhisperModelLoader."""

import os
import sys
import unittest
from unittest.mock import MagicMock, patch

import torch

# Make the project's ``src`` directory importable so ``model_loader`` can be
# imported directly from a checkout. The path is made absolute so the entry
# stays valid even if a test changes the current working directory.
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))
)

# Must come after the sys.path tweak above, hence the E402 suppression.
from model_loader import WhisperModelLoader  # noqa: E402


class TestWhisperModelLoader(unittest.TestCase):
    """Tests for ``WhisperModelLoader``.

    The loading tests patch ``from_pretrained`` on the classes that
    ``model_loader`` references, so they never invoke the real Hugging Face
    loaders. The compatibility tests pass hand-built ``MagicMock`` models.
    """

    # Checkpoint name shared by setUp and the assertions below.
    MODEL_NAME = "openai/whisper-medium"

    def setUp(self):
        """Create a fresh loader for every test."""
        self.loader = WhisperModelLoader(self.MODEL_NAME)

    def test_init(self):
        """Test WhisperModelLoader initialization."""
        self.assertEqual(self.loader.model_name, self.MODEL_NAME)
        self.assertIsInstance(self.loader.device, torch.device)

    @patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
    def test_load_model_success(self, mock_from_pretrained):
        """Test successful model loading."""
        mock_model = MagicMock()
        mock_from_pretrained.return_value = mock_model

        result = self.loader.load_model()

        # Unlike the processor test, call arguments are not pinned here.
        # NOTE(review): tighten to assert_called_once_with(...) once the
        # arguments load_model passes to from_pretrained are settled.
        mock_from_pretrained.assert_called_once()
        mock_model.to.assert_called_once_with(self.loader.device)
        # The test expects load_model to return the object produced by
        # from_pretrained itself (not whatever .to() returns).
        self.assertIs(result, mock_model)

    @patch('model_loader.WhisperProcessor.from_pretrained')
    def test_load_processor_success(self, mock_from_pretrained):
        """Test successful processor loading."""
        mock_processor = MagicMock()
        mock_from_pretrained.return_value = mock_processor

        result = self.loader.load_processor()

        mock_from_pretrained.assert_called_once_with(self.MODEL_NAME)
        self.assertIs(result, mock_processor)

    def test_verify_model_compatibility_valid_model(self):
        """Test model compatibility verification with valid model."""
        mock_model = MagicMock()
        # Set explicitly (MagicMock would auto-create them) to document which
        # sub-modules the check is expected to look for.
        mock_model.model.encoder = MagicMock()
        mock_model.model.decoder = MagicMock()
        mock_model.parameters.return_value = [MagicMock(requires_grad=True)]

        # Mock forward pass: calling the model yields an output with a
        # real tensor loss.
        mock_output = MagicMock()
        mock_output.loss = torch.tensor(1.0)
        mock_model.return_value = mock_output

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertTrue(result)

    def test_verify_model_compatibility_invalid_model(self):
        """Test model compatibility verification with invalid model."""
        mock_model = MagicMock()
        # Deleting the attribute makes any access to ``mock_model.model``
        # raise AttributeError, simulating a model without the expected
        # encoder/decoder structure.
        del mock_model.model

        result = self.loader.verify_model_compatibility(mock_model)
        self.assertFalse(result)


if __name__ == '__main__':
    # Allow running this file directly: ``python test_model_loader.py``.
    unittest.main()