| import typer |
| import torch |
| import subprocess |
| from pathlib import Path |
|
|
| from expert import UpstreamExpert |
|
|
# Files that must be present in the working directory for a submission.
SUBMISSION_FILES = [
    "expert.py",
    "model.pt",
]

# Samples per second; multiplied by each duration below to size the test waveforms.
SAMPLE_RATE = 16000

# Durations (in seconds) of the random waveforms fed to the upstream during validation.
SECONDS = [
    2,
    1.8,
    3.7,
]


# Typer application on which the CLI sub-commands are registered.
app = typer.Typer()
|
|
def _require(condition, message):
    """Raise ``ValueError(message)`` unless *condition* is truthy.

    Used in place of ``assert`` so the submission checks still run when
    Python is started with ``-O`` (which strips assert statements).
    """
    if not condition:
        raise ValueError(message)


@app.command()
def validate():
    """Validate the submission files and the upstream's output format.

    Checks that every file in ``SUBMISSION_FILES`` exists, builds
    ``UpstreamExpert`` from ``model.pt``, runs it on random waveforms and
    verifies, for every task, that the returned hidden states are a non-empty
    list of equally shaped 3-D tensors whose time dimension is consistent
    with the declared downsample rate.

    Raises:
        ValueError: if a submission file is missing or any check fails.
        Exception: anything raised by the upstream itself is re-raised after
            printing a pointer to the upstream specification.
    """
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)

        _require(
            isinstance(results, dict),
            f"upstream output must be a dict, got {type(results).__name__}",
        )
        # "hidden_states" is the fallback for every task without its own key,
        # so it is required; report its absence clearly instead of a KeyError.
        _require(
            "hidden_states" in results,
            'upstream output must contain the key "hidden_states"',
        )

        longest = max(samples)
        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            hidden_states = results.get(task, results["hidden_states"])
            _require(
                isinstance(hidden_states, list),
                f"[{task}] hidden states must be a list, got {type(hidden_states).__name__}",
            )
            # Guard the hidden_states[0] accesses below against an empty list.
            _require(len(hidden_states) > 0, f"[{task}] hidden states list is empty")

            for state in hidden_states:
                _require(
                    isinstance(state, torch.Tensor),
                    f"[{task}] every hidden state must be a torch.Tensor, got {type(state).__name__}",
                )
                _require(
                    state.dim() == 3,
                    f"[{task}] (batch_size, max_sequence_length_of_batch, hidden_size)",
                )
                _require(
                    state.shape == hidden_states[0].shape,
                    f"[{task}] all hidden states must share the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(
                isinstance(downsample_rate, int),
                f"[{task}] downsample rate must be an int, got {type(downsample_rate).__name__}",
            )
            # Avoid a ZeroDivisionError / meaningless comparison below.
            _require(downsample_rate > 0, f"[{task}] downsample rate must be positive")

            # The number of output frames for the longest waveform must be
            # within 5 frames of what the declared downsample rate implies.
            expected_frames = round(longest / downsample_rate)
            _require(
                abs(expected_frames - hidden_states[0].size(1)) < 5,
                f"[{task}] wrong downsample rate",
            )

    except Exception:
        # Point the user at the specification, then surface the real error.
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")
|
|
|
|
def _git_step(*args):
    """Run ``git <args>`` and abort the CLI with exit code 1 if it fails."""
    command = ["git", *args]
    if subprocess.call(command) != 0:
        typer.echo(f"Command failed: {' '.join(command)}", err=True)
        raise typer.Exit(code=1)


@app.command()
def upload(commit_message: str):
    """Pull, stage, commit and push the submission, then print next steps.

    Args:
        commit_message: text appended to the "Upload Upstream: " commit subject.

    Exits with status 1 (via ``typer.Exit``) as soon as a git step fails, so
    "Upload successful!" is only printed when the push actually went through.
    """
    _git_step("pull", "origin", "main")
    _git_step("add", ".")

    # `git diff --cached --quiet` exits 0 when nothing is staged; committing
    # in that state would fail, so skip straight to the push instead.
    if subprocess.call(["git", "diff", "--cached", "--quiet"]) != 0:
        _git_step("commit", "-m", f"Upload Upstream: {commit_message} ")
    else:
        typer.echo("No new changes to commit; pushing existing commits.")

    _git_step("push")

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")
|
|
def _split_remote_url(url):
    """Return ``(organization, repository)`` parsed from a git remote URL.

    Accepts both ``https://host/org/repo`` and scp-style ``git@host:org/repo``
    remotes, ignoring a trailing ``/`` or ``.git``. Returns ``None`` when the
    URL does not contain at least two path components.
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # scp-style remotes separate host and path with ":" rather than "/".
    parts = [part for part in path.replace(":", "/").split("/") if part]
    if len(parts) < 2:
        return None
    return parts[-2], parts[-1]


@app.command()
def info():
    """Print the organization, repository name and HEAD commit hash.

    The organization and repository are taken from ``remote.origin.url`` and
    the hash from ``git rev-parse HEAD``. Exits with status 1 (via
    ``typer.Exit``) when either value cannot be determined, instead of
    crashing with an ``IndexError`` or printing empty fields.
    """
    remote = subprocess.run(["git", "config", "--get", "remote.origin.url"], capture_output=True)
    url = remote.stdout.decode("utf-8").strip()
    parsed = _split_remote_url(url) if remote.returncode == 0 else None
    if parsed is None:
        typer.echo(
            "Could not determine the organization and repository from remote.origin.url. "
            "Is this a git repository with an 'origin' remote?",
            err=True,
        )
        raise typer.Exit(code=1)
    organization, repo = parsed

    head = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True)
    commit_hash = head.stdout.decode("utf-8").strip()
    # On failure git may still write to stdout, so check the exit status too.
    if head.returncode != 0 or not commit_hash:
        typer.echo("Could not determine the current commit hash (git rev-parse HEAD failed).", err=True)
        raise typer.Exit(code=1)

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")
|
|
# Script entry point: hand control to the Typer app when run directly.
if __name__ == "__main__":
    app()
|
|