Daankular committed on
Commit
cf582e0
·
verified ·
1 Parent(s): 54c1208

Upload scripts/momask_server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/momask_server.py +204 -0
scripts/momask_server.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ momask_server.py
3
+ ────────────────────────────────────────────────────────────────────────────
4
+ Lightweight Flask inference server wrapping MoMask text-to-motion generation.
5
+ Runs on the Vast.ai instance. Exposes POST /generate → [T, 263] JSON.
6
+
7
+ Does NOT require SMPL body models — only the MoMask VQ-VAE checkpoints.
8
+
9
+ Deploy
10
+ ──────
11
+ 1. Upload this file to /root/momask_server.py on the instance
12
+ 2. Install deps (see deploy_momask.sh)
13
+ 3. Run: python /root/momask_server.py --port 8765
14
+
15
+ Endpoint
16
+ ────────
17
+ POST /generate
18
+ Body: {"prompt": str, "num_frames": int, "seed": int}
19
+ Reply: {"motion": [[T, 263] as nested list], "num_frames": T, "fps": 20, "prompt": str}
20
+ """
21
+ from __future__ import annotations
22
+ import argparse
23
+ import json
24
+ import os
25
+ import sys
26
+
27
+ import numpy as np
28
+
29
+ # ── Flask ──────────────────────────────────────────────────────────────────
30
try:
    from flask import Flask, jsonify, request
except ImportError:
    # Flask is required to serve anything; exit with the install hint.
    sys.exit("pip install flask")

app = Flask(__name__)

# ── Global model state ──────────────────────────────────────────────────────
# All three stay None until _load_model() has populated them.
_model = _mean = _std = None
# max HumanML3D frames (~9.8 s at 20 fps)
_max_len = 196
42
+
43
+
44
def _load_model(momask_root: str, device: str = "cuda"):
    """Load the MoMask networks and normalisation stats into global state.

    Args:
        momask_root: Path to the MoMask checkout. It is prepended to
            ``sys.path`` so the project's ``models`` and ``options`` packages
            can be imported, and checkpoints are read from
            ``<momask_root>/checkpoints/t2m``.
        device: Device string passed to ``get_opt`` and used for ``.to()``.

    Side effects:
        Sets the module globals ``_mean``, ``_std`` and ``_model`` (a
        ``(net, trans, opt, device)`` tuple). They are assigned only after
        every file has loaded, so a failed load never leaves the model
        globals half-initialised.
    """
    global _model, _mean, _std

    sys.path.insert(0, momask_root)

    # Project-local imports; they only resolve once momask_root is on sys.path.
    # NOTE(review): module paths must match the MoMask checkout in use — confirm.
    import torch
    from models.mask_transformer.transformer import MaskTransformer
    from models.vq.model import RVQVAE
    from options.get_eval_option import get_opt
    import options.option_transformer as option_trans

    # Build each checkpoint location once instead of repeating the join.
    t2m_dir = os.path.join(momask_root, "checkpoints", "t2m")
    trans_dir = os.path.join(
        t2m_dir, "t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns")
    vqvae_ckpt = os.path.join(t2m_dir, "Comp_v6_KLD005", "net_last.pth")

    # Load options from the transformer checkpoint directory.
    opt = get_opt(os.path.join(trans_dir, "opt.txt"), device=device)

    # Load normalisation stats (from the HumanML3D dataset).
    mean = np.load(os.path.join(trans_dir, "meta", "mean.npy"))
    std = np.load(os.path.join(trans_dir, "meta", "std.npy"))

    # Residual VQ-VAE hyper-parameters: start from the parser defaults and
    # override the fields below.
    args = option_trans.get_args_parser().parse_args([])
    args.dataname = "t2m"
    args.res_name = "ter1"
    args.nb_code = 512
    args.code_dim = 512
    args.output_emb_width = 512
    args.nb_joints = 22
    args.window_size = 64
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None
    args.num_quantizers = 6

    net = RVQVAE(args,
                 263,
                 args.nb_code,
                 args.code_dim,
                 args.output_emb_width,
                 args.down_t,
                 args.stride_t,
                 args.width,
                 args.depth,
                 args.dilation_growth_rate,
                 args.vq_act,
                 args.vq_norm)

    # Load residual VQ-VAE weights.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(vqvae_ckpt, map_location="cpu")
    net.load_state_dict(ckpt["net"], strict=True)
    net.eval().to(device)

    # Load mask transformer weights.
    trans = MaskTransformer(code_dim=opt.code_dim,
                            cond_mode="text",
                            latent_dim=opt.latent_dim,
                            ff_size=opt.ff_size,
                            num_layers=opt.num_layers,
                            num_heads=opt.num_heads,
                            dropout=opt.dropout,
                            clip_dim=512,
                            cond_drop_prob=opt.cond_drop_prob,
                            clip_version=opt.clip_version,
                            opt=opt)
    trans_ckpt = torch.load(os.path.join(trans_dir, "net_last.pth"),
                            map_location="cpu")
    trans.load_state_dict(trans_ckpt["trans"], strict=True)
    trans.eval().to(device)

    # Publish everything together now that every load has succeeded.
    _mean, _std = mean, std
    _model = (net, trans, opt, device)
    print(f"[momask_server] Model loaded on {device}")
128
+
129
+
130
def _generate(prompt: str, num_frames: int, seed: int) -> np.ndarray:
    """Run MoMask inference and return the denormalised motion array.

    Args:
        prompt: Text description handed to the transformer's text encoder.
        num_frames: Requested motion length. It is clamped to
            ``[4, _max_len]`` so that at least one token (``T // 4``) is
            always requested from the transformer.
        seed: Seeds torch and NumPy when ``>= 0``; a negative value leaves
            the random state untouched.

    Returns:
        A float32 array: the decoder output for the single batch item,
        multiplied by ``_std`` and shifted by ``_mean``. The module
        docstring describes it as ``[T, 263]``.

    Raises:
        RuntimeError: If ``_load_model`` has not populated ``_model`` yet.
    """
    import torch

    if _model is None:
        raise RuntimeError("model not loaded")
    net, trans, _, _ = _model

    if seed >= 0:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # The transformer is asked for T // 4 tokens, so fewer than 4 frames
    # would request an empty sequence; clamp at both ends.
    frames_per_token = 4
    T = max(frames_per_token, min(int(num_frames), _max_len))

    with torch.no_grad():
        # CLIP text encoding
        # NOTE(review): confirm encode_text/generate signatures against the
        # installed MoMask checkout.
        cond_vector = trans.encode_text([prompt])  # [1, 77, 512]

        # MoMask iterative decoding
        mids = trans.generate(cond_vector, T // frames_per_token,
                              temperature=1.0, topk_filter_thres=0.9,
                              gsample=True, force_mask=False)  # [1, T//4, nb_code]

        # Decode token sequence → motion features via RVQVAE decoder
        motion = net.forward_decoder(mids)  # [1, T, 263]
        motion = motion[0].cpu().numpy()  # [T, 263]

    # Denormalise
    motion = motion * _std + _mean
    return motion.astype(np.float32)
159
+
160
+
161
+ # ── Routes ────────────────────────────────────────────────────────────────────
162
+
163
@app.route("/health", methods=["GET"])
def health():
    """Report liveness and whether the global model has been populated."""
    loaded = _model is not None
    return jsonify({"status": "ok", "model_loaded": loaded})
166
+
167
+
168
@app.route("/generate", methods=["POST"])
def generate():
    """Handle POST /generate: validate the JSON body and run inference.

    Request body (JSON object, all fields optional):
        prompt (str), num_frames (int), seed (int).

    Responses:
        200: ``{"motion", "num_frames", "fps", "prompt"}``.
        400: body is not a JSON object, or a field has the wrong type.
        503: the model has not been loaded.
        500: inference raised; the message is returned in ``error``.
    """
    if _model is None:
        return jsonify({"error": "model not loaded"}), 503

    # silent=True yields None on unparsable input so that we can answer with
    # a JSON error body instead of the framework's default error page.
    body = request.get_json(force=True, silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = body.get("prompt", "a person walks forward")
    if not isinstance(prompt, str):
        return jsonify({"error": "prompt must be a string"}), 400

    try:
        num_frames = int(body.get("num_frames", 120))
        seed = int(body.get("seed", -1))
    except (TypeError, ValueError):
        return jsonify({"error": "num_frames and seed must be integers"}), 400

    try:
        motion = _generate(prompt, num_frames, seed)
    except Exception as e:
        # Top-level boundary: log the traceback, report the message to the client.
        app.logger.exception("generation failed")
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "motion": motion.tolist(),
        "num_frames": int(motion.shape[0]),
        "fps": 20,
        "prompt": prompt,
    })
188
+
189
+
190
+ # ── Entry point ───────────────────────────────────────────────────────────────
191
+
192
if __name__ == "__main__":
    # Declare the command-line options in one table, then register them in order.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--momask-root", {"default": "/root/momask-codes"}),
        ("--port", {"type": int, "default": 8765}),
        ("--device", {"default": "cuda"}),
        ("--host", {"default": "0.0.0.0"}),
    ):
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Load the model before accepting any requests.
    print(f"[momask_server] Loading model from {args.momask_root} ...")
    _load_model(args.momask_root, args.device)

    # Run Flask with threaded=False, as the original entry point does.
    print(f"[momask_server] Listening on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, threaded=False)