Instructions to use zenlm/zen-translator with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use zenlm/zen-translator with Transformers:
# Use a pipeline as a high-level helper
# Warning: Pipeline type "translation" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
# pip install "transformers<5.0.0"
from transformers import pipeline
pipe = pipeline("translation", model="zenlm/zen-translator")

# Load model directly
from transformers import ZenTranslatorForSpeechTranslation
model = ZenTranslatorForSpeechTranslation.from_pretrained("zenlm/zen-translator", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files
- __init__.py +1 -0
- conftest.py +60 -0
- test_config.py +78 -0
- test_pipeline.py +66 -0
- test_training.py +136 -0
- test_wav2lip_model.py +134 -0
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Zen Translator test suite."""
|
conftest.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest configuration and fixtures."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
def sample_audio():
    """Provide a short synthetic audio clip for tests.

    Returns:
        tuple: ``(audio, sample_rate)`` where ``audio`` is a one-dimensional
        ``float32`` NumPy array holding 3 seconds of a 440 Hz sine wave and
        ``sample_rate`` is ``16000``.
    """
    # 3 seconds of audio at 16kHz
    duration_seconds = 3.0
    sample_rate = 16000
    samples = int(duration_seconds * sample_rate)

    # endpoint=False keeps the sample instants exactly 1 / sample_rate apart.
    # Including the endpoint would space them duration / (samples - 1) apart
    # and slightly detune the generated tone.
    t = np.linspace(0, duration_seconds, samples, endpoint=False)
    audio = np.sin(2 * np.pi * 440 * t).astype(np.float32)

    return audio, sample_rate
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
def sample_video_frame():
    """Provide a reproducible pseudo-random RGB frame for tests.

    Returns:
        numpy.ndarray: a ``uint8`` array of shape ``(256, 256, 3)`` whose
        values are drawn from the full 0..255 range.
    """
    # A locally seeded generator makes every test run see the same frame and
    # leaves NumPy's global random state untouched.
    rng = np.random.default_rng(0)
    # The upper bound is exclusive, so 256 is needed for 255 to be reachable.
    frame = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
    return frame
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@pytest.fixture
def temp_audio_file(tmp_path, sample_audio):
    """Write the ``sample_audio`` clip to a WAV file under ``tmp_path``.

    Returns:
        The path of the written file, ``tmp_path / "test_audio.wav"``.
    """
    # Imported inside the fixture, so soundfile is only loaded when a test
    # actually requests this fixture.
    from soundfile import write as write_audio

    waveform, rate = sample_audio
    destination = tmp_path / "test_audio.wav"
    write_audio(str(destination), waveform, rate)

    return destination
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.fixture
def translator_config():
    """Build a ``TranslatorConfig`` with explicit CPU / float32 settings.

    Lip sync and flash attention are both switched off.
    """
    from zen_translator.config import TranslatorConfig

    overrides = {
        "device": "cpu",
        "dtype": "float32",
        "enable_lip_sync": False,  # Disable for faster tests
        "use_flash_attention": False,
    }
    return TranslatorConfig(**overrides)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@pytest.fixture
def test_data_dir():
    """Return the ``data`` directory located next to this file."""
    here = Path(__file__).parent
    return here.joinpath("data")
|
test_config.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for configuration module."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TestTranslatorConfig:
    """Tests for TranslatorConfig."""

    @staticmethod
    def _clear_env_overrides(monkeypatch):
        """Unset every ``ZEN_TRANSLATOR_*`` environment variable for one test.

        ``test_config_from_env`` shows that variables with this prefix
        override configuration fields, so tests that assert built-in defaults
        must not see values leaking in from the surrounding shell or CI
        environment. ``monkeypatch`` restores the variables after the test.
        """
        import os

        # Snapshot the keys first: delenv mutates os.environ while we iterate.
        for name in list(os.environ):
            if name.startswith("ZEN_TRANSLATOR_"):
                monkeypatch.delenv(name, raising=False)

    def test_default_config(self, monkeypatch):
        """Test default configuration values."""
        from zen_translator.config import TranslatorConfig

        self._clear_env_overrides(monkeypatch)
        config = TranslatorConfig()

        assert config.target_language == "en"
        assert config.device == "cuda"
        assert config.dtype == "bfloat16"
        assert config.enable_lip_sync is True
        assert config.voice_reference_seconds == 3.0

    def test_config_from_env(self, monkeypatch):
        """Test configuration from environment variables."""
        from zen_translator.config import TranslatorConfig

        monkeypatch.setenv("ZEN_TRANSLATOR_TARGET_LANGUAGE", "es")
        monkeypatch.setenv("ZEN_TRANSLATOR_DEVICE", "cpu")

        config = TranslatorConfig()

        assert config.target_language == "es"
        assert config.device == "cpu"

    def test_supported_languages(self, monkeypatch):
        """Test supported language lists."""
        from zen_translator.config import TranslatorConfig

        self._clear_env_overrides(monkeypatch)
        config = TranslatorConfig()

        # Check input languages
        assert "en" in config.supported_input_languages
        assert "zh" in config.supported_input_languages
        assert "ja" in config.supported_input_languages
        assert "yue" in config.supported_input_languages  # Cantonese

        # Check output languages
        assert "en" in config.supported_output_languages
        assert "zh" in config.supported_output_languages
        assert len(config.supported_output_languages) == 10

    def test_lip_sync_quality_options(self):
        """Test that each lip sync quality option is stored as given."""
        from zen_translator.config import TranslatorConfig

        for quality in ["fast", "balanced", "quality"]:
            config = TranslatorConfig(lip_sync_quality=quality)
            assert config.lip_sync_quality == quality
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestNewsAnchorConfig:
    """Checks on a default-constructed ``NewsAnchorConfig``."""

    def test_default_config(self):
        """Clip duration bounds and anchor list have the expected defaults."""
        from zen_translator.config import NewsAnchorConfig

        anchor_cfg = NewsAnchorConfig()

        assert anchor_cfg.min_clip_duration == 5.0
        assert anchor_cfg.max_clip_duration == 30.0
        assert len(anchor_cfg.target_anchors) >= 1

    def test_training_settings(self):
        """Batch size, learning rate and epoch count have the expected defaults."""
        from zen_translator.config import NewsAnchorConfig

        anchor_cfg = NewsAnchorConfig()

        assert anchor_cfg.batch_size == 4
        assert anchor_cfg.learning_rate == 2e-5
        assert anchor_cfg.num_epochs == 3
|
test_pipeline.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for translation pipeline."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TestTranslationPipeline:
    """Construction and language-listing checks for ``TranslationPipeline``."""

    def test_pipeline_initialization(self, translator_config):
        """A freshly built pipeline keeps its config and is not yet loaded."""
        from zen_translator.pipeline import TranslationPipeline

        built = TranslationPipeline(translator_config)

        assert built.config == translator_config
        assert built.translator is not None
        assert built.voice_cloner is not None
        assert built._loaded is False

    def test_get_supported_languages(self, translator_config):
        """The language listing has ``input`` and ``output`` entries."""
        from zen_translator.pipeline import TranslationPipeline

        languages = TranslationPipeline(translator_config).get_supported_languages()

        assert "input" in languages
        assert "output" in languages
        assert len(languages["input"]) >= 18
        assert len(languages["output"]) == 10
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestBatchTranslationPipeline:
    """Construction check for ``BatchTranslationPipeline``."""

    def test_batch_pipeline_initialization(self, translator_config):
        """A freshly built batch pipeline keeps the config it was given."""
        from zen_translator.pipeline import BatchTranslationPipeline

        batch = BatchTranslationPipeline(translator_config)

        assert batch.config == translator_config
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TestPipelineConfig:
    """Checks on the model identifiers held by ``TranslatorConfig``."""

    def test_default_config(self):
        """Default model identifiers match the expected repository names."""
        from zen_translator import TranslatorConfig

        defaults = TranslatorConfig()

        assert defaults.qwen3_omni_model == "Qwen/Qwen3-Omni-30B-A3B-Instruct"
        assert defaults.cosyvoice_model == "FunAudioLLM/CosyVoice2-0.5B"
        assert defaults.wav2lip_model == "numz/wav2lip_studio"

    def test_custom_model_paths(self):
        """Model paths passed to the constructor are stored unchanged."""
        from zen_translator import TranslatorConfig

        local_paths = {
            "qwen3_omni_model": "./local/qwen3-omni",
            "cosyvoice_model": "./local/cosyvoice",
        }
        customised = TranslatorConfig(**local_paths)

        assert customised.qwen3_omni_model == "./local/qwen3-omni"
        assert customised.cosyvoice_model == "./local/cosyvoice"
|
test_training.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for training infrastructure."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TestSwiftConfig:
    """Checks on ``SwiftTrainingConfig`` defaults and its two export forms."""

    def test_default_config(self):
        """Default model type, train type and LoRA sizes are as expected."""
        from zen_translator.training import SwiftTrainingConfig

        swift_cfg = SwiftTrainingConfig()

        assert swift_cfg.model_type == "qwen3-omni"
        assert swift_cfg.train_type == "lora"
        assert swift_cfg.lora_rank == 64
        assert swift_cfg.lora_alpha == 128

    def test_to_swift_args(self):
        """``to_swift_args`` output contains the expected ``--key=value`` items."""
        from zen_translator.training import SwiftTrainingConfig

        cli_args = SwiftTrainingConfig().to_swift_args()

        assert "--model_type=qwen3-omni" in cli_args
        assert "--train_type=lora" in cli_args
        assert "--lora_rank=64" in cli_args

    def test_to_yaml(self, tmp_path):
        """``to_yaml`` writes a file whose parsed content has the expected keys."""
        from zen_translator.training import SwiftTrainingConfig

        swift_cfg = SwiftTrainingConfig()
        target = tmp_path / "config.yaml"

        swift_cfg.to_yaml(target)

        assert target.exists()

        # Parse the written file back and inspect it.
        import yaml

        with target.open() as stream:
            parsed = yaml.safe_load(stream)

        assert parsed["model"]["type"] == "qwen3-omni"
        assert parsed["lora"]["rank"] == 64
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TestZenIdentityConfig:
    """Checks on the default ``ZenIdentityConfig``."""

    def test_identity_system_prompt(self):
        """The system prompt mentions both expected names."""
        from zen_translator.training import ZenIdentityConfig

        prompt = ZenIdentityConfig().system_prompt

        assert "Zen Translator" in prompt
        assert "Hanzo AI" in prompt
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TestNewsAnchorConfig:
    """Checks on the training-side ``NewsAnchorConfig`` defaults."""

    def test_anchor_names(self):
        """The anchor list is non-empty and contains ``cnn`` and ``bbc``."""
        from zen_translator.training import NewsAnchorConfig

        anchor_cfg = NewsAnchorConfig()

        assert len(anchor_cfg.anchor_names) >= 1
        assert "cnn" in anchor_cfg.anchor_names
        assert "bbc" in anchor_cfg.anchor_names

    def test_news_domains(self):
        """The domain list contains ``politics`` and ``technology``."""
        from zen_translator.training import NewsAnchorConfig

        anchor_cfg = NewsAnchorConfig()

        assert "politics" in anchor_cfg.news_domains
        assert "technology" in anchor_cfg.news_domains
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TestNewsChannels:
    """Tests for predefined news channels."""

    def test_channels_defined(self):
        """Check that NEWS_CHANNELS is non-empty and has the expected keys."""
        from zen_translator.training import NEWS_CHANNELS

        assert len(NEWS_CHANNELS) > 0
        for channel in ("cnn", "bbc", "nhk"):
            assert channel in NEWS_CHANNELS, f"missing news channel: {channel}"

    def test_channel_urls(self):
        """Check that every channel URL is an https URL mentioning youtube.com."""
        from zen_translator.training import NEWS_CHANNELS

        # The channel name goes into each failure message so a failing entry
        # can be identified without re-running the loop by hand.
        for name, url in NEWS_CHANNELS.items():
            assert url.startswith("https://"), f"{name}: URL is not https: {url!r}"
            assert "youtube.com" in url, f"{name}: not a YouTube URL: {url!r}"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class TestCreateTrainingDataset:
    """Checks on ``create_training_dataset`` output."""

    def test_create_jsonl_dataset(self, tmp_path):
        """One conversation yields a one-line file holding a JSON object."""
        from zen_translator.training import create_training_dataset

        turns = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        conversations = [{"conversations": turns}]

        destination = tmp_path / "train.jsonl"
        create_training_dataset(conversations, destination, format="jsonl")

        assert destination.exists()

        # Read the file back and decode its single record.
        import json

        with destination.open() as handle:
            records = list(handle)

        assert len(records) == 1
        decoded = json.loads(records[0])
        assert "conversations" in decoded
|
test_wav2lip_model.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Wav2Lip model architecture."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestWav2LipModel:
    """Construction and output-shape checks for the Wav2Lip modules."""

    def test_model_initialization(self):
        """A new ``Wav2Lip`` exposes its three sub-modules."""
        from zen_translator.lip_sync.wav2lip_model import Wav2Lip

        net = Wav2Lip()

        assert net.audio_encoder is not None
        assert net.face_encoder is not None
        assert net.face_decoder is not None

    def test_model_forward_shape(self):
        """A forward pass on random inputs returns a ``(B, 3, 96, 96)`` tensor."""
        from zen_translator.lip_sync.wav2lip_model import Wav2Lip

        net = Wav2Lip()
        net.eval()

        # Dummy input sizes
        batch_size = 2
        mel_length = 16
        mel_channels = 80

        # Audio input: (batch, 1, 1, mel_channels, mel_length)
        audio = torch.randn(batch_size, 1, 1, mel_channels, mel_length)

        # Face input: (batch, 6, 96, 96)
        face = torch.randn(batch_size, 6, 96, 96)

        with torch.no_grad():
            generated = net(audio, face)

        expected_shape = (batch_size, 3, 96, 96)
        assert generated.shape == expected_shape

    def test_audio_encoder(self):
        """The audio encoder output has size 512 on its third-from-last axis."""
        from zen_translator.lip_sync.wav2lip_model import AudioEncoder

        audio_net = AudioEncoder()
        audio_net.eval()

        batch_size = 2
        audio = torch.randn(batch_size, 1, 1, 80, 16)

        with torch.no_grad():
            embedding = audio_net(audio)

        # Expect a 512-wide embedding
        assert embedding.shape[-3] == 512

    def test_face_encoder(self):
        """The face encoder returns a collection of 7 items."""
        from zen_translator.lip_sync.wav2lip_model import FaceEncoder

        face_net = FaceEncoder()
        face_net.eval()

        batch_size = 2
        face = torch.randn(batch_size, 6, 96, 96)

        with torch.no_grad():
            feature_maps = face_net(face)

        # Expect 7 feature maps
        assert len(feature_maps) == 7
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TestConvBlocks:
    """Output-shape checks for the convolution building blocks."""

    def test_conv2d_block(self):
        """A stride-1, padding-1 ``Conv2d`` keeps height and width."""
        from zen_translator.lip_sync.wav2lip_model import Conv2d

        layer = Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        sample = torch.randn(1, 3, 64, 64)

        result = layer(sample)

        assert result.shape == (1, 32, 64, 64)

    def test_conv2d_residual(self):
        """A ``Conv2d`` built with ``residual=True`` keeps the input shape."""
        from zen_translator.lip_sync.wav2lip_model import Conv2d

        layer = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)
        sample = torch.randn(1, 32, 64, 64)

        result = layer(sample)

        # Only the output shape is checked here.
        assert result.shape == (1, 32, 64, 64)

    def test_transpose_conv2d(self):
        """A stride-2 ``ConvTranspose2d`` doubles height and width."""
        from zen_translator.lip_sync.wav2lip_model import ConvTranspose2d

        layer = ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        sample = torch.randn(1, 32, 32, 32)

        result = layer(sample)

        # 32x32 in, 64x64 out
        assert result.shape == (1, 16, 64, 64)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class TestSyncDiscriminator:
    """Output checks for ``SyncDiscriminator``."""

    def test_discriminator_output(self):
        """The output is ``(batch, 1)`` with every value inside ``[0, 1]``."""
        from zen_translator.lip_sync.wav2lip_model import SyncDiscriminator

        critic = SyncDiscriminator()
        critic.eval()

        batch_size = 2
        mel = torch.randn(batch_size, 80, 16)
        face = torch.randn(batch_size, 3, 96, 96)

        with torch.no_grad():
            output = critic(mel, face)

        # One value per batch item, bounded like a probability
        assert output.shape == (batch_size, 1)
        assert torch.all(output >= 0) and torch.all(output <= 1)
|