"""
Unit tests for WhisperModelLoader
"""

import os
import sys
import unittest
from unittest.mock import MagicMock, patch

import torch

# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from model_loader import WhisperModelLoader  # noqa: E402  (needs the sys.path tweak above)


class TestWhisperModelLoader(unittest.TestCase):
    """Tests for WhisperModelLoader with Hugging Face loading mocked out."""

    # Checkpoint name the loader under test is constructed with; the
    # processor test asserts it is forwarded verbatim to from_pretrained.
    MODEL_NAME = "openai/whisper-medium"

    def setUp(self):
        self.loader = WhisperModelLoader(self.MODEL_NAME)

    def test_init(self):
        """Test WhisperModelLoader initialization"""
        self.assertEqual(self.loader.model_name, self.MODEL_NAME)
        self.assertIsInstance(self.loader.device, torch.device)

    @patch('model_loader.WhisperForConditionalGeneration.from_pretrained')
    def test_load_model_success(self, mock_from_pretrained):
        """Test successful model loading"""
        mock_model = MagicMock()
        # torch.nn.Module.to() returns the module itself. A bare MagicMock
        # would return a new child mock instead, so mirror the real behavior.
        # This makes the test pass whether load_model calls `model.to(device)`
        # for its side effect or reassigns `model = model.to(device)`.
        mock_model.to.return_value = mock_model
        mock_from_pretrained.return_value = mock_model

        result = self.loader.load_model()

        mock_from_pretrained.assert_called_once()
        mock_model.to.assert_called_once_with(self.loader.device)
        self.assertIs(result, mock_model)

    @patch('model_loader.WhisperProcessor.from_pretrained')
    def test_load_processor_success(self, mock_from_pretrained):
        """Test successful processor loading"""
        mock_processor = MagicMock()
        mock_from_pretrained.return_value = mock_processor

        result = self.loader.load_processor()

        mock_from_pretrained.assert_called_once_with(self.MODEL_NAME)
        self.assertIs(result, mock_processor)

    def test_verify_model_compatibility_valid_model(self):
        """Test model compatibility verification with valid model"""
        mock_model = MagicMock()
        # Give the mock the encoder/decoder submodules and a trainable
        # parameter; presumably verify_model_compatibility inspects these
        # (TODO confirm against model_loader).
        mock_model.model.encoder = MagicMock()
        mock_model.model.decoder = MagicMock()
        mock_model.parameters.return_value = [MagicMock(requires_grad=True)]

        # Mock forward pass: calling the model yields an output with a scalar loss
        mock_output = MagicMock()
        mock_output.loss = torch.tensor(1.0)
        mock_model.return_value = mock_output

        result = self.loader.verify_model_compatibility(mock_model)

        self.assertTrue(result)

    def test_verify_model_compatibility_invalid_model(self):
        """Test model compatibility verification with invalid model"""
        mock_model = MagicMock()
        # Remove required attributes: after `del`, accessing mock_model.model
        # raises AttributeError instead of auto-creating a child mock.
        del mock_model.model

        result = self.loader.verify_model_compatibility(mock_model)

        self.assertFalse(result)


if __name__ == '__main__':
    unittest.main()