Daankular commited on
Commit
f06ae6c
·
verified ·
1 Parent(s): 7b514c3

Upload scripts/texture_i2tex.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/texture_i2tex.py +120 -0
scripts/texture_i2tex.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+ from torchvision import transforms
7
+ from transformers import AutoModelForImageSegmentation
8
+
9
+ from mvadapter.pipelines.pipeline_texture import ModProcessConfig, TexturePipeline
10
+ from mvadapter.utils import make_image_grid
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--device", type=str, default="cuda")
15
+ parser.add_argument("--variant", type=str, default="sdxl", choices=["sdxl", "sd21"])
16
+ # I/O
17
+ parser.add_argument("--mesh", type=str, required=True)
18
+ parser.add_argument("--image", type=str, required=True)
19
+ parser.add_argument("--text", type=str, default="high quality")
20
+ parser.add_argument("--seed", type=int, default=-1)
21
+ parser.add_argument("--save_dir", type=str, default="./output")
22
+ parser.add_argument("--save_name", type=str, default="i2tex_sample")
23
+ # Extra
24
+ parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
25
+ parser.add_argument("--preprocess_mesh", action="store_true")
26
+ parser.add_argument("--remove_bg", action="store_true")
27
+ args = parser.parse_args()
28
+
29
+ if args.variant == "sdxl":
30
+ from .inference_ig2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
31
+
32
+ base_model = "stabilityai/stable-diffusion-xl-base-1.0"
33
+ vae_model = "madebyollin/sdxl-vae-fp16-fix"
34
+ height = width = 768
35
+ uv_size = 4096
36
+ elif args.variant == "sd21":
37
+ from .inference_ig2mv_sd import prepare_pipeline, remove_bg, run_pipeline
38
+
39
+ base_model = "stabilityai/stable-diffusion-2-1-base"
40
+ vae_model = None
41
+ height = width = 512
42
+ uv_size = 2048
43
+ else:
44
+ raise ValueError(f"Invalid variant: {args.variant}")
45
+
46
+ device = args.device
47
+ num_views = 6
48
+
49
+ # Prepare pipelines
50
+ pipe = prepare_pipeline(
51
+ base_model=base_model,
52
+ vae_model=vae_model,
53
+ unet_model=None,
54
+ lora_model=None,
55
+ adapter_path="huanngzh/mv-adapter",
56
+ scheduler=None,
57
+ num_views=num_views,
58
+ device=device,
59
+ dtype=torch.float16,
60
+ )
61
+ if args.remove_bg:
62
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
63
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
64
+ )
65
+ birefnet.to(args.device)
66
+ transform_image = transforms.Compose(
67
+ [
68
+ transforms.Resize((1024, 1024)),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
71
+ ]
72
+ )
73
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
74
+ else:
75
+ remove_bg_fn = None
76
+
77
+ texture_pipe = TexturePipeline(
78
+ upscaler_ckpt_path="./checkpoints/RealESRGAN_x2plus.pth",
79
+ inpaint_ckpt_path=None,
80
+ device=device,
81
+ )
82
+ print("Pipeline ready.")
83
+
84
+ os.makedirs(args.save_dir, exist_ok=True)
85
+
86
+ # 1. run MV-Adapter to generate multi-view images
87
+ images, _, _, _ = run_pipeline(
88
+ pipe,
89
+ mesh_path=args.mesh,
90
+ num_views=num_views,
91
+ text=args.text,
92
+ image=args.image,
93
+ height=height,
94
+ width=width,
95
+ num_inference_steps=50,
96
+ guidance_scale=3.0,
97
+ seed=args.seed,
98
+ reference_conditioning_scale=args.reference_conditioning_scale,
99
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
100
+ device=device,
101
+ remove_bg_fn=remove_bg_fn,
102
+ )
103
+ mv_path = os.path.join(args.save_dir, f"{args.save_name}.png")
104
+ make_image_grid(images, rows=1).save(mv_path)
105
+
106
+ torch.cuda.empty_cache()
107
+
108
+ # 2. un-project and complete texture
109
+ out = texture_pipe(
110
+ mesh_path=args.mesh,
111
+ save_dir=args.save_dir,
112
+ save_name=args.save_name,
113
+ uv_unwarp=True,
114
+ preprocess_mesh=args.preprocess_mesh,
115
+ uv_size=uv_size,
116
+ rgb_path=mv_path,
117
+ rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="uv"),
118
+ camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
119
+ )
120
+ print(f"Output saved to {out.shaded_model_save_path}")