| import argparse |
| import random |
| import os |
| from datetime import datetime |
|
|
| import numpy as np |
| import torch |
| import json |
| from torch.optim import AdamW |
| from torchvision.transforms import v2 |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
| from transformers import get_constant_schedule_with_warmup |
|
|
| from Datasets import CCDataset, Batcher |
| from model import ICCModel |
| from utils import get_vocabulary |
| from Loss import InfoNCELoss |
| from eval import captioning, retrieve, plot |
| from huggingface_hub import hf_hub_download |
| import open_clip |
|
|
|
|
def _batch_loss(args, model, infonce, batch, device):
    """Return the combined objective ``cap_loss + args.lamb * con_loss`` for one batch.

    The batch tensors are moved to ``device``. ``model`` supplies the captioning
    loss and the visual/text embeddings, and ``infonce`` supplies the contrastive loss.
    """
    imgs1 = batch['images_before'].to(device)
    imgs2 = batch['images_after'].to(device)
    toks = batch['input_ids'].to(device)
    labs = batch['labels'].to(device)
    flags = batch['flags'].to(device)
    attention_mask = batch['pad_mask'].to(device)
    embs = batch['embs'].to(device)

    cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)
    con_loss, _ = infonce(vis_emb, text_emb, flags, embs)
    return cap_loss + args.lamb * con_loss


def _grad_norm(parameters, device):
    """Return the global L2 norm of all existing gradients, or 0.0 if none exist."""
    norms = [torch.norm(p.grad.detach()).to(device) for p in parameters if p.grad is not None]
    if not norms:
        # torch.stack([]) raises, e.g. when every parameter is frozen.
        return 0.0
    return torch.norm(torch.stack(norms)).item()


def train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer):
    """Train ``model`` for ``args.epochs`` epochs, validating after each epoch.

    Each training step minimises ``cap_loss + args.lamb * con_loss``. It optionally
    clips gradients to ``args.max_grad_norm``, then steps the optimizer and the
    scheduler. Train loss, gradient norm and learning rate are logged to ``writer``.

    After every epoch the weights are saved as ``model_<step>.pt`` in
    ``args.output_path``, and the mean validation loss is logged as ``eval_score``.
    At the end, the checkpoint with the lowest validation loss is re-saved as
    ``model_best.pt``.

    Args:
        args: parsed CLI namespace (uses epochs, lamb, max_grad_norm, output_path).
        model: model returning ``(cap_loss, vis_emb, text_emb, ...)``.
        train_loader / valid_loader: iterables of batch dicts.
        device: torch device the batch tensors are moved to.
        infonce: contrastive loss callable returning ``(loss, num_pos)``.
        optim, scheduler: optimizer and LR scheduler, stepped once per batch.
        writer: TensorBoard ``SummaryWriter``.
    """
    step = 0
    best_score = float("inf")
    best_model = None

    for epoch in range(args.epochs):
        model.train()
        for batch in tqdm(train_loader, desc='Epoch ' + str(epoch)):
            loss = _batch_loss(args, model, infonce, batch, device)
            loss.backward()

            if args.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            # Computed unconditionally so the 'grad' scalar always exists.
            # When clipping is enabled, this is the post-clipping norm.
            grad = _grad_norm(model.parameters(), device)

            optim.step()
            scheduler.step()
            optim.zero_grad()

            writer.add_scalar('train_loss', loss.item(), step)
            writer.add_scalar('grad', grad, step)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], step)
            step += 1

        # One checkpoint per epoch, named after the global step count.
        torch.save(model.state_dict(), os.path.join(args.output_path, 'model_{}.pt'.format(step)))

        model.eval()
        eval_losses = []
        with torch.no_grad():
            for batch in tqdm(valid_loader, desc='Validation ' + str(epoch)):
                eval_losses.append(_batch_loss(args, model, infonce, batch, device).item())
        # Mean of per-batch losses. An empty loader yields NaN, which never compares as best.
        eval_score = sum(eval_losses) / len(eval_losses) if eval_losses else float('nan')
        writer.add_scalar('eval_score', eval_score, step)

        if eval_score < best_score:
            best_score = eval_score
            best_model = step

    if best_model is not None:
        # os.path.join keeps paths correct whether or not output_path ends with a separator.
        best_path = os.path.join(args.output_path, 'model_{}.pt'.format(best_model))
        state_dict = torch.load(best_path, map_location=device)
        torch.save(state_dict, os.path.join(args.output_path, 'model_best.pt'))
|
|
|
|
def _load_vocab(args):
    """Load the vocabulary JSON at ``args.vocab``, or build it with ``get_vocabulary`` if absent."""
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            return json.load(infile)
    return get_vocabulary(args.annotation_json, args.vocab)


def _load_backbone(args, config):
    """Return ``(clip, preprocess)`` for ``config['backbone']``.

    For backbones whose name contains 'resnet', no CLIP model is loaded (``None``)
    and images get an ImageNet-normalised ``v2`` pipeline. Otherwise the
    RemoteCLIP weights for the backbone are downloaded into ``args.pretrained``
    and loaded into an ``open_clip`` model. That model's third returned
    transform is used for preprocessing.
    """
    if 'resnet' in config['backbone']:
        preprocess = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return None, preprocess

    checkpoint_path = hf_hub_download("chendelong/RemoteCLIP",
                                      f"RemoteCLIP-{config['backbone']}.pt",
                                      cache_dir=args.pretrained)
    clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    clip.load_state_dict(ckpt)
    return clip, preprocess


def run(args, config):
    """Build model and data, train, then evaluate the best checkpoint on the test split.

    Args:
        args: parsed CLI namespace (paths, optimisation hyper-parameters, seed).
        config: dict loaded from the JSON model config (backbone, model sizes, ...).
    """
    print('Initializing...')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Checkpoints are written straight into output_path, so make sure it exists.
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)
    dt_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    writer = SummaryWriter(os.path.join(args.output_path, dt_str))

    vocab = _load_vocab(args)
    clip, preprocess = _load_backbone(args, config)

    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])
    model = model.to(device)
    # Drop this function's reference to the CLIP backbone.
    del clip

    print('Loading...')
    training_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'train', config['max_len'],
                             config['s-transformers'], device)
    valid_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'val', config['max_len'],
                          config['s-transformers'], device)
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
                         config['s-transformers'], device)

    train_loader = Batcher(training_set, args.batch_size, config['max_len'], device, args.hd, model=model, shuffle=True)
    valid_loader = Batcher(valid_set, args.batch_size, config['max_len'], device)
    test_loader = Batcher(test_set, 1, config['max_len'], device)

    print('Training...')
    infonce = InfoNCELoss(device, k=args.k, temperature=args.temperature, threshold=config['s-threshold'],
                          fna=config['fna'])
    optim = AdamW([x for x in model.parameters() if x.requires_grad], lr=args.learning_rate, eps=args.adam_epsilon)
    # args.warmup_steps is multiplied by the total number of batches over all epochs.
    scheduler = get_constant_schedule_with_warmup(optim,
                                                  num_warmup_steps=args.warmup_steps * len(train_loader) * args.epochs)
    try:
        train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer)
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()

    print('Final evaluation...')
    model.load_state_dict(torch.load(os.path.join(args.output_path, 'model_best.pt'), map_location=device))
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
|
|
|
|
def main():
    """Parse command-line options, load the JSON model config and hand both to ``run``."""
    parser = argparse.ArgumentParser()

    # Input data, pretrained weights, model config and output locations.
    path_options = (
        ('--annotation_json', '../input/Levir_CC/LevirCCcaptions.json'),
        ('--image_dir', '../input/Levir_CC/images/'),
        ('--vocab', '../input/levir_vocab.json'),
        ('--pretrained', '../../input/checkpoints'),
        ('--config', '../config.json'),
        ('--output_path', '../output/'),
    )
    for flag, default in path_options:
        parser.add_argument(flag, type=str, default=default)

    # Training and optimisation hyper-parameters: (flag, type, default).
    # --lr_decay is parsed but not read by run() or train() in this file.
    hyper_options = (
        ('--epochs', int, 50),
        ('--batch_size', int, 4),
        ('--k', int, -1),
        ('--hd', int, -1),
        ('--learning_rate', float, 1e-4),
        ('--warmup_steps', float, 0.025),
        ('--lr_decay', float, 0.7),
        ('--adam_epsilon', float, 1e-8),
        ('--max_grad_norm', float, None),
        ('--temperature', float, 0.01),
        ('--lamb', float, 0.5),
        ('--seed', int, 42),
    )
    for flag, kind, default in hyper_options:
        parser.add_argument(flag, type=kind, default=default)

    args = parser.parse_args()

    with open(args.config, 'r') as handle:
        model_config = json.load(handle)

    run(args, model_config)


if __name__ == '__main__':
    main()
|
|