| | import os |
| | from typing import Optional, Literal |
| | from types import ModuleType |
| | import enum |
| | from packaging import version |
| |
|
| | import torch |
| |
|
| | |
# Detect PyTorch's fused scaled-dot-product attention by probing for the
# function itself instead of comparing version strings. Pre-release and
# source builds report versions such as "2.0.0a0+git..." which sort *before*
# "2.0.0", so a version comparison would wrongly disable SDP on them.
SDP_IS_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")
| |
|
# xformers is an optional dependency: probe for it once at import time.
# NOTE: the flag keeps its existing spelling ("AVAILBLE") because it is a
# module-level name that other code may already import.
try:
    import xformers
    import xformers.ops
except Exception:
    # Stay broad on purpose (a broken install may fail with something other
    # than ImportError), but do not use a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate.
    XFORMERS_IS_AVAILBLE = False
else:
    XFORMERS_IS_AVAILBLE = True
| |
|
| |
|
@enum.unique
class AttnMode(enum.Enum):
    """Attention implementations that can be selected at runtime."""

    # PyTorch scaled-dot-product attention.
    SDP = 0
    # Attention provided by the optional xformers package.
    XFORMERS = 1
    # Plain fallback implementation.
    VANILLA = 2
| |
|
| |
|
class Config:
    """Process-wide attention settings, stored as class attributes.

    The attributes are assigned directly on the class (never on an
    instance) by the module-level setup code.
    """

    # Handle to the imported xformers module; stays None when it is absent.
    xformers: Optional[ModuleType] = None
    # Currently selected attention implementation; vanilla until overridden.
    attn_mode: AttnMode = AttnMode.VANILLA
| |
|
| |
|
| | |
# Expose the xformers module through Config whenever it could be imported,
# independently of which backend ends up being the default.
if XFORMERS_IS_AVAILBLE:
    Config.xformers = xformers

# Choose the default backend. Preference order: SDP, then xformers, and only
# if neither is present the vanilla fallback (Config's initial value).
if not SDP_IS_AVAILABLE and not XFORMERS_IS_AVAILBLE:
    print("both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
elif SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
else:
    Config.attn_mode = AttnMode.XFORMERS
    print("use xformers attention as default")
| |
|
| |
|
| | |
# Optional override of the default backend through the ATTN_MODE environment
# variable. Accepted values: "vanilla", "sdp", "xformers".
#
# Validation uses explicit raises rather than ``assert``: assertions are
# removed under ``python -O``, which would let an invalid value silently
# select vanilla attention, or select a backend that is not installed.
ATTN_MODE = os.environ.get("ATTN_MODE")
if ATTN_MODE is None:
    print("keep default attention mode")
else:
    if ATTN_MODE not in ("vanilla", "sdp", "xformers"):
        raise ValueError(
            f"invalid ATTN_MODE {ATTN_MODE!r}; expected one of "
            "'vanilla', 'sdp', 'xformers'"
        )
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError(
                "ATTN_MODE=sdp was requested but sdp attention is not available"
            )
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError(
                "ATTN_MODE=xformers was requested but xformers is not available"
            )
        Config.attn_mode = AttnMode.XFORMERS
    else:
        Config.attn_mode = AttnMode.VANILLA
    print(f"set attention mode to {ATTN_MODE}")
| |
|