| | import unittest |
| | import torch |
| | import sys |
| | import os |
| |
|
| | |
# Put the parent of this file's directory on the import path so that the
# project-local `src` package imported below can be found.
sys.path.append(
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), '..')
    )
)
| |
|
| | from src.config import ModelConfig |
| | from src.models.autoencoder import LatentAutoencoder |
| | from src.models.dit import FlowDiT |
| |
|
class TestModels(unittest.TestCase):
    """Shape smoke tests for ``LatentAutoencoder`` and ``FlowDiT``.

    Every test builds a small model from a reduced ``ModelConfig``, feeds it
    random inputs, and asserts only the shapes of the outputs.
    """

    # Batch size used for all synthetic inputs.
    BATCH_SIZE = 2
    # Sequence length; passed to ModelConfig as ``max_seq_len`` and reused in
    # the expected shapes so the config and the assertions cannot drift apart.
    SEQ_LEN = 32
    # Latent width; passed to ModelConfig as ``latent_dim`` and reused in the
    # expected shapes for the same reason.
    LATENT_DIM = 128
    # Exclusive upper bound for the random token ids fed to the autoencoder.
    MAX_TOKEN_ID = 100
    # Vocabulary size of the ``roberta-base`` tokenizer; the autoencoder's
    # logits are expected to have this many classes in their last dimension.
    # NOTE(review): assumes the decoder head is sized to the encoder's
    # tokenizer vocabulary — confirm against LatentAutoencoder.
    ROBERTA_BASE_VOCAB_SIZE = 50265
    # Guidance scale passed to ``FlowDiT.forward_with_cfg``.
    CFG_SCALE = 3.0

    def setUp(self):
        # Seed torch's global RNG so weight initialisation and the random
        # inputs are identical on every run, making failures reproducible.
        torch.manual_seed(0)
        # Reduced configuration (two decoder / DiT layers) keeps model
        # construction cheap.
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.LATENT_DIM,
            max_seq_len=self.SEQ_LEN,
            decoder_layers=2,
            dit_layers=2,
        )

    def _dit_inputs(self):
        """Return random ``(x, t, cond)`` inputs for ``FlowDiT``.

        ``x`` and ``cond`` have shape ``(BATCH_SIZE, SEQ_LEN, LATENT_DIM)``
        and are drawn from a standard normal; ``t`` has shape
        ``(BATCH_SIZE,)`` with values drawn uniformly from ``[0, 1)``.
        """
        latent_shape = (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM)
        x = torch.randn(latent_shape)
        t = torch.rand(self.BATCH_SIZE)
        cond = torch.randn(latent_shape)
        return x, t, cond

    def test_ae_shape(self):
        """The autoencoder returns logits and latents of the expected shapes."""
        print("\nTesting Autoencoder Shape...")
        model = LatentAutoencoder(self.cfg)
        input_ids = torch.randint(
            0, self.MAX_TOKEN_ID, (self.BATCH_SIZE, self.SEQ_LEN)
        )
        mask = torch.ones((self.BATCH_SIZE, self.SEQ_LEN))
        # Only shapes are checked, so there is no need to build an autograd graph.
        with torch.no_grad():
            logits, z = model(input_ids, mask)

        self.assertEqual(z.shape, (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM))
        self.assertEqual(
            logits.shape,
            (self.BATCH_SIZE, self.SEQ_LEN, self.ROBERTA_BASE_VOCAB_SIZE),
        )
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        """A conditioned ``FlowDiT`` forward pass preserves the latent shape."""
        print("\nTesting DiT Shape...")
        model = FlowDiT(self.cfg)
        x, t, cond = self._dit_inputs()

        with torch.no_grad():
            out = model(x, t, condition=cond)
        self.assertEqual(out.shape, (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM))
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        """``FlowDiT.forward_with_cfg`` preserves the latent shape."""
        print("\nTesting CFG Forward...")
        model = FlowDiT(self.cfg)
        x, t, cond = self._dit_inputs()

        with torch.no_grad():
            out = model.forward_with_cfg(x, t, cond, cfg_scale=self.CFG_SCALE)
        self.assertEqual(out.shape, (self.BATCH_SIZE, self.SEQ_LEN, self.LATENT_DIM))
        print("CFG Check Passed.")
| |
|
# Script entry point: runs this module's tests only when the file is
# executed directly; importing it (e.g. from a test runner) runs nothing here.
if __name__ == "__main__":
    unittest.main()