zeekay committed on
Commit
2272e87
·
verified ·
1 Parent(s): cc4882e

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. __init__.py +1 -0
  2. conftest.py +60 -0
  3. test_config.py +78 -0
  4. test_pipeline.py +66 -0
  5. test_training.py +136 -0
  6. test_wav2lip_model.py +134 -0
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Zen Translator test suite."""
conftest.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pytest configuration and fixtures."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+
9
+ @pytest.fixture
10
+ def sample_audio():
11
+ """Generate sample audio data for testing."""
12
+ # 3 seconds of audio at 16kHz
13
+ duration_seconds = 3.0
14
+ sample_rate = 16000
15
+ samples = int(duration_seconds * sample_rate)
16
+
17
+ # Generate a simple sine wave
18
+ t = np.linspace(0, duration_seconds, samples)
19
+ audio = np.sin(2 * np.pi * 440 * t).astype(np.float32)
20
+
21
+ return audio, sample_rate
22
+
23
+
24
+ @pytest.fixture
25
+ def sample_video_frame():
26
+ """Generate sample video frame for testing."""
27
+ # RGB frame 256x256
28
+ frame = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
29
+ return frame
30
+
31
+
32
+ @pytest.fixture
33
+ def temp_audio_file(tmp_path, sample_audio):
34
+ """Create a temporary audio file."""
35
+ import soundfile as sf
36
+
37
+ audio, sr = sample_audio
38
+ audio_path = tmp_path / "test_audio.wav"
39
+ sf.write(str(audio_path), audio, sr)
40
+
41
+ return audio_path
42
+
43
+
44
+ @pytest.fixture
45
+ def translator_config():
46
+ """Create test translator configuration."""
47
+ from zen_translator.config import TranslatorConfig
48
+
49
+ return TranslatorConfig(
50
+ device="cpu",
51
+ dtype="float32",
52
+ enable_lip_sync=False, # Disable for faster tests
53
+ use_flash_attention=False,
54
+ )
55
+
56
+
57
+ @pytest.fixture
58
+ def test_data_dir():
59
+ """Get test data directory."""
60
+ return Path(__file__).parent / "data"
test_config.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for configuration module."""
2
+
3
+
4
+ class TestTranslatorConfig:
5
+ """Tests for TranslatorConfig."""
6
+
7
+ def test_default_config(self):
8
+ """Test default configuration values."""
9
+ from zen_translator.config import TranslatorConfig
10
+
11
+ config = TranslatorConfig()
12
+
13
+ assert config.target_language == "en"
14
+ assert config.device == "cuda"
15
+ assert config.dtype == "bfloat16"
16
+ assert config.enable_lip_sync is True
17
+ assert config.voice_reference_seconds == 3.0
18
+
19
+ def test_config_from_env(self, monkeypatch):
20
+ """Test configuration from environment variables."""
21
+ from zen_translator.config import TranslatorConfig
22
+
23
+ monkeypatch.setenv("ZEN_TRANSLATOR_TARGET_LANGUAGE", "es")
24
+ monkeypatch.setenv("ZEN_TRANSLATOR_DEVICE", "cpu")
25
+
26
+ config = TranslatorConfig()
27
+
28
+ assert config.target_language == "es"
29
+ assert config.device == "cpu"
30
+
31
+ def test_supported_languages(self):
32
+ """Test supported language lists."""
33
+ from zen_translator.config import TranslatorConfig
34
+
35
+ config = TranslatorConfig()
36
+
37
+ # Check input languages
38
+ assert "en" in config.supported_input_languages
39
+ assert "zh" in config.supported_input_languages
40
+ assert "ja" in config.supported_input_languages
41
+ assert "yue" in config.supported_input_languages # Cantonese
42
+
43
+ # Check output languages
44
+ assert "en" in config.supported_output_languages
45
+ assert "zh" in config.supported_output_languages
46
+ assert len(config.supported_output_languages) == 10
47
+
48
+ def test_lip_sync_quality_options(self):
49
+ """Test lip sync quality options."""
50
+ from zen_translator.config import TranslatorConfig
51
+
52
+ for quality in ["fast", "balanced", "quality"]:
53
+ config = TranslatorConfig(lip_sync_quality=quality)
54
+ assert config.lip_sync_quality == quality
55
+
56
+
57
+ class TestNewsAnchorConfig:
58
+ """Tests for NewsAnchorConfig."""
59
+
60
+ def test_default_config(self):
61
+ """Test default news anchor config."""
62
+ from zen_translator.config import NewsAnchorConfig
63
+
64
+ config = NewsAnchorConfig()
65
+
66
+ assert config.min_clip_duration == 5.0
67
+ assert config.max_clip_duration == 30.0
68
+ assert len(config.target_anchors) > 0
69
+
70
+ def test_training_settings(self):
71
+ """Test training hyperparameters."""
72
+ from zen_translator.config import NewsAnchorConfig
73
+
74
+ config = NewsAnchorConfig()
75
+
76
+ assert config.batch_size == 4
77
+ assert config.learning_rate == 2e-5
78
+ assert config.num_epochs == 3
test_pipeline.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for translation pipeline."""
2
+
3
+
4
+ class TestTranslationPipeline:
5
+ """Tests for TranslationPipeline."""
6
+
7
+ def test_pipeline_initialization(self, translator_config):
8
+ """Test pipeline can be initialized."""
9
+ from zen_translator.pipeline import TranslationPipeline
10
+
11
+ pipeline = TranslationPipeline(translator_config)
12
+
13
+ assert pipeline.config == translator_config
14
+ assert pipeline.translator is not None
15
+ assert pipeline.voice_cloner is not None
16
+ assert pipeline._loaded is False
17
+
18
+ def test_get_supported_languages(self, translator_config):
19
+ """Test getting supported languages."""
20
+ from zen_translator.pipeline import TranslationPipeline
21
+
22
+ pipeline = TranslationPipeline(translator_config)
23
+ languages = pipeline.get_supported_languages()
24
+
25
+ assert "input" in languages
26
+ assert "output" in languages
27
+ assert len(languages["input"]) >= 18
28
+ assert len(languages["output"]) == 10
29
+
30
+
31
+ class TestBatchTranslationPipeline:
32
+ """Tests for BatchTranslationPipeline."""
33
+
34
+ def test_batch_pipeline_initialization(self, translator_config):
35
+ """Test batch pipeline can be initialized."""
36
+ from zen_translator.pipeline import BatchTranslationPipeline
37
+
38
+ pipeline = BatchTranslationPipeline(translator_config)
39
+
40
+ assert pipeline.config == translator_config
41
+
42
+
43
+ class TestPipelineConfig:
44
+ """Tests for pipeline configuration options."""
45
+
46
+ def test_default_config(self):
47
+ """Test default pipeline configuration."""
48
+ from zen_translator import TranslatorConfig
49
+
50
+ config = TranslatorConfig()
51
+
52
+ assert config.qwen3_omni_model == "Qwen/Qwen3-Omni-30B-A3B-Instruct"
53
+ assert config.cosyvoice_model == "FunAudioLLM/CosyVoice2-0.5B"
54
+ assert config.wav2lip_model == "numz/wav2lip_studio"
55
+
56
+ def test_custom_model_paths(self):
57
+ """Test custom model path configuration."""
58
+ from zen_translator import TranslatorConfig
59
+
60
+ config = TranslatorConfig(
61
+ qwen3_omni_model="./local/qwen3-omni",
62
+ cosyvoice_model="./local/cosyvoice",
63
+ )
64
+
65
+ assert config.qwen3_omni_model == "./local/qwen3-omni"
66
+ assert config.cosyvoice_model == "./local/cosyvoice"
test_training.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for training infrastructure."""
2
+
3
+
4
+ class TestSwiftConfig:
5
+ """Tests for ms-swift training configuration."""
6
+
7
+ def test_default_config(self):
8
+ """Test default training config."""
9
+ from zen_translator.training import SwiftTrainingConfig
10
+
11
+ config = SwiftTrainingConfig()
12
+
13
+ assert config.model_type == "qwen3-omni"
14
+ assert config.train_type == "lora"
15
+ assert config.lora_rank == 64
16
+ assert config.lora_alpha == 128
17
+
18
+ def test_to_swift_args(self):
19
+ """Test conversion to swift CLI arguments."""
20
+ from zen_translator.training import SwiftTrainingConfig
21
+
22
+ config = SwiftTrainingConfig()
23
+ args = config.to_swift_args()
24
+
25
+ assert "--model_type=qwen3-omni" in args
26
+ assert "--train_type=lora" in args
27
+ assert "--lora_rank=64" in args
28
+
29
+ def test_to_yaml(self, tmp_path):
30
+ """Test YAML export."""
31
+ from zen_translator.training import SwiftTrainingConfig
32
+
33
+ config = SwiftTrainingConfig()
34
+ yaml_path = tmp_path / "config.yaml"
35
+
36
+ config.to_yaml(yaml_path)
37
+
38
+ assert yaml_path.exists()
39
+
40
+ # Verify content
41
+ import yaml
42
+
43
+ with open(yaml_path) as f:
44
+ saved = yaml.safe_load(f)
45
+
46
+ assert saved["model"]["type"] == "qwen3-omni"
47
+ assert saved["lora"]["rank"] == 64
48
+
49
+
50
+ class TestZenIdentityConfig:
51
+ """Tests for Zen identity finetuning config."""
52
+
53
+ def test_identity_system_prompt(self):
54
+ """Test identity system prompt is set."""
55
+ from zen_translator.training import ZenIdentityConfig
56
+
57
+ config = ZenIdentityConfig()
58
+
59
+ assert "Zen Translator" in config.system_prompt
60
+ assert "Hanzo AI" in config.system_prompt
61
+
62
+
63
+ class TestNewsAnchorConfig:
64
+ """Tests for news anchor training config."""
65
+
66
+ def test_anchor_names(self):
67
+ """Test anchor names are configured."""
68
+ from zen_translator.training import NewsAnchorConfig
69
+
70
+ config = NewsAnchorConfig()
71
+
72
+ assert len(config.anchor_names) > 0
73
+ assert "cnn" in config.anchor_names
74
+ assert "bbc" in config.anchor_names
75
+
76
+ def test_news_domains(self):
77
+ """Test news domains are configured."""
78
+ from zen_translator.training import NewsAnchorConfig
79
+
80
+ config = NewsAnchorConfig()
81
+
82
+ assert "politics" in config.news_domains
83
+ assert "technology" in config.news_domains
84
+
85
+
86
+ class TestNewsChannels:
87
+ """Tests for predefined news channels."""
88
+
89
+ def test_channels_defined(self):
90
+ """Test news channels are defined."""
91
+ from zen_translator.training import NEWS_CHANNELS
92
+
93
+ assert len(NEWS_CHANNELS) > 0
94
+ assert "cnn" in NEWS_CHANNELS
95
+ assert "bbc" in NEWS_CHANNELS
96
+ assert "nhk" in NEWS_CHANNELS
97
+
98
+ def test_channel_urls(self):
99
+ """Test channel URLs are valid."""
100
+ from zen_translator.training import NEWS_CHANNELS
101
+
102
+ for name, url in NEWS_CHANNELS.items():
103
+ assert url.startswith("https://")
104
+ assert "youtube.com" in url
105
+
106
+
107
+ class TestCreateTrainingDataset:
108
+ """Tests for dataset creation."""
109
+
110
+ def test_create_jsonl_dataset(self, tmp_path):
111
+ """Test JSONL dataset creation."""
112
+ from zen_translator.training import create_training_dataset
113
+
114
+ conversations = [
115
+ {
116
+ "conversations": [
117
+ {"role": "user", "content": "Hello"},
118
+ {"role": "assistant", "content": "Hi there!"},
119
+ ]
120
+ }
121
+ ]
122
+
123
+ output_path = tmp_path / "train.jsonl"
124
+ create_training_dataset(conversations, output_path, format="jsonl")
125
+
126
+ assert output_path.exists()
127
+
128
+ # Verify content
129
+ import json
130
+
131
+ with open(output_path) as f:
132
+ lines = f.readlines()
133
+
134
+ assert len(lines) == 1
135
+ data = json.loads(lines[0])
136
+ assert "conversations" in data
test_wav2lip_model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Wav2Lip model architecture."""
2
+
3
+ import torch
4
+
5
+
6
+ class TestWav2LipModel:
7
+ """Tests for Wav2Lip neural network."""
8
+
9
+ def test_model_initialization(self):
10
+ """Test model can be initialized."""
11
+ from zen_translator.lip_sync.wav2lip_model import Wav2Lip
12
+
13
+ model = Wav2Lip()
14
+
15
+ assert model.audio_encoder is not None
16
+ assert model.face_encoder is not None
17
+ assert model.face_decoder is not None
18
+
19
+ def test_model_forward_shape(self):
20
+ """Test model forward pass produces correct output shape."""
21
+ from zen_translator.lip_sync.wav2lip_model import Wav2Lip
22
+
23
+ model = Wav2Lip()
24
+ model.eval()
25
+
26
+ # Create dummy inputs
27
+ batch_size = 2
28
+ mel_length = 16
29
+ mel_channels = 80
30
+
31
+ # Audio: (B, T, 1, 80, 16) -> mel spectrogram windows
32
+ audio = torch.randn(batch_size, 1, 1, mel_channels, mel_length)
33
+
34
+ # Face: (B, 6, 96, 96) -> half face + reference
35
+ face = torch.randn(batch_size, 6, 96, 96)
36
+
37
+ with torch.no_grad():
38
+ output = model(audio, face)
39
+
40
+ # Output should be (B, 3, 96, 96)
41
+ assert output.shape == (batch_size, 3, 96, 96)
42
+
43
+ def test_audio_encoder(self):
44
+ """Test audio encoder produces correct embedding."""
45
+ from zen_translator.lip_sync.wav2lip_model import AudioEncoder
46
+
47
+ encoder = AudioEncoder()
48
+ encoder.eval()
49
+
50
+ batch_size = 2
51
+ audio = torch.randn(batch_size, 1, 1, 80, 16)
52
+
53
+ with torch.no_grad():
54
+ embedding = encoder(audio)
55
+
56
+ # Should produce 512-dim embedding
57
+ assert embedding.shape[-3] == 512
58
+
59
+ def test_face_encoder(self):
60
+ """Test face encoder produces feature hierarchy."""
61
+ from zen_translator.lip_sync.wav2lip_model import FaceEncoder
62
+
63
+ encoder = FaceEncoder()
64
+ encoder.eval()
65
+
66
+ batch_size = 2
67
+ face = torch.randn(batch_size, 6, 96, 96)
68
+
69
+ with torch.no_grad():
70
+ features = encoder(face)
71
+
72
+ # Should produce 7 feature maps (one per block)
73
+ assert len(features) == 7
74
+
75
+
76
+ class TestConvBlocks:
77
+ """Tests for convolution building blocks."""
78
+
79
+ def test_conv2d_block(self):
80
+ """Test Conv2d block."""
81
+ from zen_translator.lip_sync.wav2lip_model import Conv2d
82
+
83
+ block = Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
84
+ x = torch.randn(1, 3, 64, 64)
85
+
86
+ out = block(x)
87
+
88
+ assert out.shape == (1, 32, 64, 64)
89
+
90
+ def test_conv2d_residual(self):
91
+ """Test Conv2d with residual connection."""
92
+ from zen_translator.lip_sync.wav2lip_model import Conv2d
93
+
94
+ block = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
95
+ x = torch.randn(1, 32, 64, 64)
96
+
97
+ out = block(x)
98
+
99
+ # With residual, only the output shape is checked: it must match the input shape
100
+ assert out.shape == (1, 32, 64, 64)
101
+
102
+ def test_transpose_conv2d(self):
103
+ """Test ConvTranspose2d block."""
104
+ from zen_translator.lip_sync.wav2lip_model import ConvTranspose2d
105
+
106
+ block = ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
107
+ x = torch.randn(1, 32, 32, 32)
108
+
109
+ out = block(x)
110
+
111
+ # Should upsample by factor of 2
112
+ assert out.shape == (1, 16, 64, 64)
113
+
114
+
115
+ class TestSyncDiscriminator:
116
+ """Tests for sync discriminator."""
117
+
118
+ def test_discriminator_output(self):
119
+ """Test sync discriminator produces probability."""
120
+ from zen_translator.lip_sync.wav2lip_model import SyncDiscriminator
121
+
122
+ discriminator = SyncDiscriminator()
123
+ discriminator.eval()
124
+
125
+ batch_size = 2
126
+ mel = torch.randn(batch_size, 80, 16)
127
+ face = torch.randn(batch_size, 3, 96, 96)
128
+
129
+ with torch.no_grad():
130
+ output = discriminator(mel, face)
131
+
132
+ # Should produce sync probability
133
+ assert output.shape == (batch_size, 1)
134
+ assert torch.all(output >= 0) and torch.all(output <= 1)