| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import glob |
| | import os |
| | import re |
| | import subprocess |
| |
|
| |
|
| | |
| | |
# Both paths are relative, i.e. resolved against the current working directory, so the script is expected to be
# run from the repository root (e.g. `python utils/check_copies.py`).
DIFFUSERS_PATH = "src/diffusers"  # directory scanned for `# Copied from` blocks and holding their sources
REPO_PATH = "."
| |
|
| |
|
| | def _should_continue(line, indent): |
| | return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None |
| |
|
| |
|
def find_code_in_diffusers(object_name):
    """
    Find and return the source code of `object_name`.

    Args:
        object_name (`str`):
            Dotted path of the object relative to the `diffusers` package (the part after `diffusers.` in a
            `# Copied from` comment). Leading parts are joined into a module path until they name an existing
            `.py` file under `DIFFUSERS_PATH`; the remaining parts name a top-level class/function followed by
            nested classes/methods, each looked up one indentation level deeper than its parent.

    Returns:
        `str`: The lines following the object's `class`/`def` line, up to where its indentation ends, with
        trailing blank lines removed. The `class`/`def` line itself is not included.

    Raises:
        `ValueError`: If no prefix of `object_name` names a module file, or if the object is not found in it.
    """
    parts = object_name.split(".")
    i = 0

    # Find the module file: extend the candidate path one part at a time until `<module>.py` exists.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")

    with open(
        os.path.join(DIFFUSERS_PATH, f"{module}.py"),
        "r",
        encoding="utf-8",
        newline="\n",
    ) as f:
        lines = f.readlines()

    # Locate each remaining part in turn (class, then nested class/method, ...), scanning forward from the
    # previous match and expecting one more indentation level each time.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        # `re.escape` so the name is matched literally rather than as a regex fragment.
        definition = re.compile(rf"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)")
        while line_index < len(lines) and definition.search(lines[line_index]) is None:
            line_index += 1
        # One indentation level is four spaces (ruff's default width; presumably the repository's). A narrower
        # step would make the anchored `^{indent}(class|def)` search above miss nested methods.
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # The body starts right after the `class`/`def` line and runs until the indentation drops back.
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Drop trailing blank lines.
    while len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return "".join(lines[start_index:line_index])
| |
|
| |
|
| | _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)") |
| | _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") |
| | _re_fill_pattern = re.compile(r"<FILL\s+[^>]*>") |
| |
|
| |
|
def get_indent(code):
    """
    Return the indentation of the first line of `code` that contains a non-whitespace character.

    Args:
        code (`str`): Source code, possibly multi-line.

    Returns:
        `str`: The leading whitespace of that line, or `""` when `code` has no non-blank line.
    """
    for line in code.split("\n"):
        # Skip empty *and* whitespace-only lines: they carry no meaningful indentation, and the pattern does not
        # match them (previously such a line made `.groups()` be called on `None`).
        match = re.match(r"(\s*)\S", line)
        if match is not None:
            return match.group(1)
    return ""
| |
|
| |
|
def run_ruff(code):
    """
    Format `code` by piping it through `ruff format -` with the `pyproject.toml` configuration.

    `ruff` must be on `PATH`, and `pyproject.toml` is resolved relative to the current working directory.

    Args:
        code (`str`): The source code to format.

    Returns:
        `str`: The formatted code, as written by `ruff` to stdout.

    Raises:
        `RuntimeError`: If `ruff` exits with a non-zero status (for example because `code` is not valid Python).
    """
    command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) as process:
        stdout, stderr = process.communicate(input=code.encode())
    # A failed run leaves stdout empty. Returning it would make every copy look inconsistent and, with
    # `--fix_and_overwrite`, replace the copied code with nothing — so fail loudly instead.
    if process.returncode != 0:
        raise RuntimeError(
            f"`{' '.join(command)}` failed with exit code {process.returncode}:\n" + stderr.decode(errors="replace")
        )
    return stdout.decode()
| |
|
| |
|
def stylify(code: str) -> str:
    """
    Format `code` the way the ruff step of `make style` does.

    `ruff` has no Python API, so it is run as an external command. Indented code (e.g. a method body) cannot be
    formatted on its own at module level. It is therefore temporarily nested under a dummy `class Bla:` header,
    which is sliced off the formatted result again.

    Args:
        code (`str`): The code to format.

    Returns:
        `str`: The formatted code.
    """
    wrapper_header = "class Bla:\n"
    if not get_indent(code):
        # Top-level code can be formatted as-is.
        return run_ruff(code)
    formatted_code = run_ruff(wrapper_header + code)
    return formatted_code[len(wrapper_header) :]
| |
|
| |
|
def _find_copy_end(lines, start_index, indent):
    """Return the index just past the copied block that starts at `lines[start_index]`, trailing blanks excluded."""
    # The block ends where the indentation drops below `indent`, or at an explicit `# End copy` marker.
    line_index = start_index
    should_continue = True
    while line_index < len(lines) and should_continue:
        line_index += 1
        if line_index >= len(lines):
            break
        line = lines[line_index]
        should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
    # Clean up empty lines at the end (if any).
    while len(lines[line_index - 1]) <= 1:
        line_index -= 1
    return line_index


def _apply_replace_patterns(code, replace_pattern):
    """Apply the `with A->B, C->D all-casing` substitutions of a `Copied from` comment to `code`."""
    # `with` is a Python keyword and therefore never a whole identifier, so strip it only as a whole word.
    # A plain `str.replace("with", "")` also mangled names containing it, such as `with_bias` or `without_cfg`.
    patterns = re.sub(r"\bwith\b", "", replace_pattern).split(",")
    for pattern in (_re_replace_pattern.search(p) for p in patterns):
        if pattern is None:
            continue
        obj1, obj2, option = pattern.groups()
        code = re.sub(obj1, obj2, code)
        if option.strip() == "all-casing":
            code = re.sub(obj1.lower(), obj2.lower(), code)
            code = re.sub(obj1.upper(), obj2.upper(), code)
    return code


def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Args:
        filename (`str`): Path of the file to check.
        overwrite (`bool`, *optional*, defaults to `False`):
            When `True`, inconsistent copies are replaced by the (transformed) original and the file is rewritten.

    Returns:
        `List[list]`: One `[object_name, start_index]` pair per inconsistent copy, where `start_index` is the
        0-based index of the first line of the copied code.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # Not a `for` loop: `lines` is rebuilt when `overwrite=True`.
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_diffusers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        # When the comment sits at the outer level, the next line is the `class`/`def` header: skip it too.
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        line_index = _find_copy_end(lines, start_index, theoretical_indent)
        observed_code = "".join(lines[start_index:line_index])

        # Remove any nested `Copied from` comments to avoid circular copies.
        theoretical_code = "\n".join(
            line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None
        )

        if len(replace_pattern) > 0:
            theoretical_code = _apply_replace_patterns(theoretical_code, replace_pattern)
            # Re-format after replacement; the header line is prepended so `ruff` sees a valid enclosing block.
            header = lines[start_index - 1]
            theoretical_code = stylify(header + theoretical_code)[len(header) :]

        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs
| |
|
| |
|
def check_copies(overwrite: bool = False):
    """
    Check every `# Copied from diffusers...` block in the `.py` files under `DIFFUSERS_PATH`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            When `True`, inconsistent copies are rewritten in place instead of being reported.

    Raises:
        `Exception`: If `overwrite` is `False` and at least one copy does not match its original.
    """
    # `glob` yields files in arbitrary, filesystem-dependent order; sorting keeps the report and the order in
    # which files are rewritten reproducible across machines and runs.
    all_files = sorted(glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True))
    diffs = []
    for filename in all_files:
        new_diffs = is_copy_consistent(filename, overwrite)
        diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: report inconsistent copies, or fix them with `--fix_and_overwrite`.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    cli_args = cli_parser.parse_args()

    check_copies(cli_args.fix_and_overwrite)
| |
|