# FUSE / run.py
# pihaisun's picture
# code
# 6a89de0
import argparse
import cv2
import matplotlib
import numpy as np
import os
import torch
import torch.nn.functional as F
# from model.FUSE_dis import FUSE
from model.FUSE import FUSE
from model.fuse.utils import clean_pretrained_weight
from dataset.mvsec import MVSEC
from dataset.eventscape import EventScape
from dataset.dense import Dense
from torch.utils.data import DataLoader
def build_valset(dataset, scene, size, normalized_depth):
    """Build the evaluation dataset for a (dataset, scene) pair.

    Args:
        dataset: Dataset family name ("mvsec", "eventscape" or "dense").
        scene: Scene / split name within that family.
        size: ``(height, width)`` tuple forwarded to the dataset as ``size``.
        normalized_depth: Forwarded to the dataset as ``normalized_d``.

    Returns:
        The dataset instance, always constructed in "val" mode.

    Raises:
        NotImplementedError: If the (dataset, scene) combination is unsupported.
    """
    # Table lookup replaces five copy-pasted branches that differed only in
    # the dataset class and the split file.
    splits = {
        ("mvsec", "day1"): (MVSEC, "dataset/splits/mvsec/outdoor_day1.txt"),
        ("mvsec", "night1"): (MVSEC, "dataset/splits/mvsec/outdoor_night1.txt"),
        ("mvsec", "train"): (MVSEC, "dataset/splits/mvsec/train.txt"),
        ("eventscape", "test"): (EventScape, "dataset/splits/eventscape/test.txt"),
        ("dense", "test"): (Dense, "dataset/splits/dense/test.txt"),
    }
    try:
        dataset_cls, split_file = splits[(dataset, scene)]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported dataset/scene combination: {dataset!r}/{scene!r}"
        ) from None
    return dataset_cls(
        split_file,
        "val",
        normalized_d=normalized_depth,
        size=size,
    )


def normalize_depth_to_uint8(depth):
    """Min-max scale a depth map to the 0..255 range and cast to ``uint8``.

    Args:
        depth: NumPy array of depth values.

    Returns:
        ``uint8`` array of the same shape. A constant (or all-NaN) map yields
        all zeros instead of dividing by zero and casting NaN to ``uint8``.
    """
    lowest = depth.min()
    span = depth.max() - lowest
    # ``span > 0`` is also False for NaN, so degenerate maps take the safe path.
    if span > 0:
        scaled = (depth - lowest) / span * 255.0
    else:
        scaled = np.zeros_like(depth)
    return scaled.astype(np.uint8)


def colorize_depth(depth_u8, cmap, grayscale):
    """Turn a ``uint8`` depth map into a 3-channel image for ``cv2.imwrite``.

    Args:
        depth_u8: 2-D ``uint8`` array (output of ``normalize_depth_to_uint8``).
        cmap: Matplotlib colormap; only used when ``grayscale`` is false.
        grayscale: If true, replicate the single channel three times.

    Returns:
        ``uint8`` array with a trailing channel axis of length 3.
    """
    if grayscale:
        return np.repeat(depth_u8[..., np.newaxis], 3, axis=-1)
    # The colormap returns RGBA floats in [0, 1]: drop alpha, scale to 0..255,
    # and reverse the channel order because OpenCV writes images as BGR.
    return (cmap(depth_u8)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)


def main(argv=None):
    """Run FUSE inference over a validation split and save the predictions.

    For every sample the predicted depth is resized to the resolution of the
    image at ``sample["image_path"]``, optionally saved as ``.npy`` under
    ``<outdir>/npy``, and written as a PNG under ``<outdir>/vis`` (either the
    prediction alone or side by side with the input image).

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Raises:
        NotImplementedError: For an unsupported dataset/scene combination.
        FileNotFoundError: If a sample's image cannot be read from disk.
    """
    parser = argparse.ArgumentParser(description="EPDE")
    parser.add_argument("--input-size", type=int)
    parser.add_argument("--encoder", type=str, choices=["vits", "vitb", "vitl", "vitg"])
    parser.add_argument("--max-depth", type=float, default=20)
    parser.add_argument("--event-voxel-chans", type=int)
    # Without these two the script always failed later with an opaque
    # traceback (os.makedirs(None) / torch.load(None)); fail fast instead.
    parser.add_argument("--outdir", type=str, required=True)
    parser.add_argument("--dataset", choices=["mvsec", "eventscape", "dense"])
    parser.add_argument("--save-numpy", action="store_true")
    parser.add_argument("--pred-only", action="store_true")
    parser.add_argument("--grayscale", action="store_true")
    parser.add_argument("--normalized-depth", action="store_true")
    parser.add_argument("--return-feature", action="store_true")
    parser.add_argument("--load-from", type=str, required=True)
    parser.add_argument(
        "--scene",
        choices=["day1", "night1", "train", "test"],
    )
    args = parser.parse_args(argv)

    size = (args.input_size, args.input_size)
    valset = build_valset(args.dataset, args.scene, size, args.normalized_depth)
    valloader = DataLoader(
        valset,
        batch_size=1,
        pin_memory=True,
        num_workers=4,
        drop_last=True,
    )

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Instantiate the model and load the pretrained weights.
    model = FUSE(
        model_name=args.encoder,
        max_depth=args.max_depth,
        event_voxel_chans=args.event_voxel_chans,
        return_feature=args.return_feature,
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(args.load_from, map_location="cpu")
    checkpoint = clean_pretrained_weight(checkpoint)
    model.load_state_dict(checkpoint)
    print(f"Model weights load from {args.load_from} successfully!")
    model = model.to(device).eval()

    os.makedirs(args.outdir, exist_ok=True)
    npy_dir = os.path.join(args.outdir, "npy")
    os.makedirs(npy_dir, exist_ok=True)
    vis_dir = os.path.join(args.outdir, "vis")
    os.makedirs(vis_dir, exist_ok=True)

    cmap = matplotlib.colormaps.get_cmap("Spectral")
    total = len(valloader)
    for k, sample in enumerate(valloader):
        rgb_path = sample["image_path"][0]
        print(f"Progress {k+1}/{total}: {rgb_path}")

        raw_image = cv2.imread(rgb_path)
        # cv2.imread signals failure by returning None rather than raising.
        if raw_image is None:
            raise FileNotFoundError(f"Could not read image: {rgb_path}")
        h, w = raw_image.shape[:2]

        with torch.no_grad():
            inputs = sample["input"].to(device)
            depth = model(inputs)
            # NOTE(review): assumes the model returns a (batch, H, W) tensor;
            # a channel axis is added for interpolation and removed again.
            depth = F.interpolate(
                depth[:, None], (h, w), mode="bilinear", align_corners=True
            )[0, 0]
        depth = depth.cpu().numpy()

        # Output files are named after the input image, minus its extension.
        stem = os.path.splitext(os.path.basename(rgb_path))[0]
        if args.save_numpy:
            # Save the raw prediction before it is rescaled for visualisation.
            np.save(os.path.join(npy_dir, stem + ".npy"), depth)

        depth_vis = colorize_depth(
            normalize_depth_to_uint8(depth), cmap, args.grayscale
        )
        output_path = os.path.join(vis_dir, stem + ".png")
        if args.pred_only:
            cv2.imwrite(output_path, depth_vis)
        else:
            # White 50-pixel-wide separator between the input and prediction.
            split_region = np.ones((raw_image.shape[0], 50, 3), dtype=np.uint8) * 255
            combined_result = cv2.hconcat([raw_image, split_region, depth_vis])
            cv2.imwrite(output_path, combined_result)


if __name__ == "__main__":
    main()