Daankular commited on
Commit
74c8926
Β·
verified Β·
1 Parent(s): cf582e0

Upload scripts/patch_pshuman_vram.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/patch_pshuman_vram.py +126 -0
scripts/patch_pshuman_vram.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ patch_pshuman_vram.py
3
+ =====================
4
+ Apply VRAM-reduction optimisations to /root/PSHuman/inference.py.
5
+
6
+ Patches applied:
7
+ 1. load_pshuman_pipeline β€” adds VAE slicing + CPU offload on top of
8
+ the existing fp16 + xformers that are already in the file.
9
+ 2. run_inference β€” adds torch.cuda.empty_cache() after the pipeline
10
+ call so fragmented VRAM is reclaimed between multi-view denoising.
11
+
12
+ fp32 @ 768 res β‰ˆ 40 GB. fp16 β‰ˆ 20 GB. fp16 + xformers β‰ˆ 16-18 GB.
13
+ fp16 + xformers + VAE slicing + CPU offload β‰ˆ 14-16 GB peak β†’ fits 24 GB.
14
+
15
+ Run:
16
+ /root/miniconda/envs/pshuman/bin/python /root/MeshForge/scripts/patch_pshuman_vram.py
17
+ """
18
+ import pathlib, sys
19
+
20
+ TARGET = pathlib.Path("/root/PSHuman/inference.py")
21
+ if not TARGET.exists():
22
+ sys.exit(f"ERROR: {TARGET} not found β€” run after PSHuman is cloned")
23
+
24
+ src = TARGET.read_text()
25
+ original = src # keep a backup reference
26
+
27
+ # ─────────────────────────────────────────────────────────────────
28
+ # Patch 1: load_pshuman_pipeline
29
+ # ─────────────────────────────────────────────────────────────────
30
+ OLD_LOAD = """\
31
+ def load_pshuman_pipeline(cfg):
32
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
33
+ pipeline.unet.enable_xformers_memory_efficient_attention()
34
+ if torch.cuda.is_available():
35
+ pipeline.to('cuda')
36
+ return pipeline"""
37
+
38
+ NEW_LOAD = """\
39
+ def load_pshuman_pipeline(cfg):
40
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
41
+ cfg.pretrained_model_name_or_path,
42
+ torch_dtype=weight_dtype, # float16 β€” halves VRAM vs fp32
43
+ )
44
+
45
+ # xformers: reduces peak VRAM during multi-head denoising attention
46
+ try:
47
+ pipeline.unet.enable_xformers_memory_efficient_attention()
48
+ print("[PSHuman] xformers memory-efficient attention enabled")
49
+ except Exception as _xe:
50
+ print(f"[PSHuman] xformers unavailable ({_xe}) β€” falling back to attention slicing")
51
+ pipeline.unet.enable_attention_slicing(1)
52
+
53
+ # VAE slicing: prevents OOM when decoding a 7-view 768-res batch at once
54
+ if hasattr(pipeline, "enable_vae_slicing"):
55
+ pipeline.enable_vae_slicing()
56
+ print("[PSHuman] VAE slicing enabled")
57
+
58
+ # CPU offload: idle pipeline components (text encoder, VAE, safety checker)
59
+ # move to RAM when not actively used, freeing ~3-4 GB of static VRAM.
60
+ # pipeline() is called via standard diffusers __call__, so hooks work.
61
+ if torch.cuda.is_available():
62
+ try:
63
+ pipeline.enable_model_cpu_offload()
64
+ print("[PSHuman] model CPU offload enabled")
65
+ except Exception as _oe:
66
+ print(f"[PSHuman] CPU offload unavailable ({_oe}) β€” loading to CUDA directly")
67
+ pipeline.to("cuda")
68
+
69
+ return pipeline"""
70
+
71
+ if OLD_LOAD in src:
72
+ src = src.replace(OLD_LOAD, NEW_LOAD)
73
+ print("[patch 1] load_pshuman_pipeline β€” VRAM optimisations applied")
74
+ elif "enable_vae_slicing" in src:
75
+ print("[patch 1] load_pshuman_pipeline β€” already patched, skipping")
76
+ else:
77
+ # Looser match for minor whitespace/version differences
78
+ import re
79
+ m = re.search(
80
+ r'def load_pshuman_pipeline\(cfg\):.*?return pipeline',
81
+ src, re.DOTALL
82
+ )
83
+ if m:
84
+ src = src[:m.start()] + NEW_LOAD + src[m.end():]
85
+ print("[patch 1] load_pshuman_pipeline β€” applied via regex fallback")
86
+ else:
87
+ print("[patch 1] WARNING: could not locate load_pshuman_pipeline β€” skipping")
88
+
89
+ # ─────────────────────────────────────────────────────────────────
90
+ # Patch 2: empty CUDA cache after pipeline call in run_inference
91
+ # ─────────────────────────────────────────────────────────────────
92
+ # Insert torch.cuda.empty_cache() right after the pipeline __call__ block.
93
+ # The existing code already has `torch.cuda.empty_cache()` at the bottom of
94
+ # the batch loop β€” so only add if it's missing near the unet_out line.
95
+
96
+ OLD_CACHE_ANCHOR = """\
97
+ with torch.autocast("cuda"):
98
+ # B*Nv images
99
+ guidance_scale = cfg.validation_guidance_scales
100
+ unet_out = pipeline("""
101
+
102
+ NEW_CACHE_ANCHOR = """\
103
+ torch.cuda.empty_cache() # free fragmented VRAM before denoising
104
+ with torch.autocast("cuda"):
105
+ # B*Nv images
106
+ guidance_scale = cfg.validation_guidance_scales
107
+ unet_out = pipeline("""
108
+
109
+ if OLD_CACHE_ANCHOR in src and "empty_cache() # free fragmented" not in src:
110
+ src = src.replace(OLD_CACHE_ANCHOR, NEW_CACHE_ANCHOR)
111
+ print("[patch 2] run_inference β€” added pre-denoising cache flush")
112
+ else:
113
+ print("[patch 2] run_inference β€” cache flush already present or anchor not found, skipping")
114
+
115
+ # ─────────────────────────────────────────────────────────────────
116
+ # Write back only if changed
117
+ # ─────────────────────────────────────────────────────────────────
118
+ if src != original:
119
+ backup = TARGET.with_suffix(".py.orig")
120
+ if not backup.exists():
121
+ backup.write_text(original)
122
+ print(f"[patch] Backup saved β†’ {backup}")
123
+ TARGET.write_text(src)
124
+ print(f"[patch] Written β†’ {TARGET}")
125
+ else:
126
+ print("[patch] No changes made.")