# miqa / train.py
# Training script from the "miqa" repository, uploaded to the Hugging Face Hub
# by xiaoqi-wang with huggingface_hub (commit 4e62d30, verified).
import numpy as np
import os
import random
import shutil
import time
import warnings
from enum import Enum
import pandas as pd
import torch
import logging
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torch.distributed as dist
import torch.multiprocessing as mp
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
from configs.args_base import get_args
from utils.losses import build_loss_function
from utils.optimizer import build_optimizer
from utils.lr_scheduler import build_scheduler
from utils.logger import build_logger
from utils.misc import setup_seed, reduce_tensor, save_checkpoint
from data import build_dataloader
from models.MIQA_base import get_torch_model, get_timm_model
from models.RA_MIQA import RegionVisionTransformer
# Best validation correlations seen so far. They live at module level so that
# main_worker and train_one_epoch can share and update them via ``global``.
best_srcc, best_plcc, best_klcc = 0.0, 0.0, 0.0
def main(args):
    """Prepare run-wide settings, then start one or more training workers.

    Seeds the RNGs when ``args.seed`` is set and derives ``args.world_size``
    and ``args.distributed``. It then either spawns one ``main_worker``
    process per visible GPU (``args.multiprocessing_distributed``) or calls
    ``main_worker`` directly in the current process.
    """
    if args.seed is not None:
        setup_seed(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # With env:// initialisation and no explicit size, read it from the launcher.
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    cuda_available = torch.cuda.is_available()
    ngpus_per_node = torch.cuda.device_count() if cuda_available else 1
    if cuda_available and ngpus_per_node == 1 and args.dist_backend == "nccl":
        warnings.warn(
            "nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")

    if not args.multiprocessing_distributed:
        # Single process: run the worker directly on the requested GPU (if any).
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # One process per GPU on this node, so the global world size is scaled
    # by the per-node GPU count before spawning.
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def main_worker(gpu, ngpus_per_node, args):
    """Build the model and data, then train (or only evaluate) on one device.

    Called directly by ``main`` for single-process runs, or once per GPU via
    ``torch.multiprocessing.spawn`` for multiprocessing-distributed runs.

    Args:
        gpu: Local GPU index for this process, or ``None`` to let the code
            choose (DataParallel on CUDA, MPS, or CPU).
        ngpus_per_node: Number of GPUs on this node (1 when CUDA is absent).
        args: Parsed command-line namespace. It is mutated in place
            (``gpu``, ``ngpus_per_node``, ``rank``, ``batch_size``,
            ``workers``, ``start_epoch``).

    Side effects:
        Updates the module-level ``best_srcc``/``best_plcc``/``best_klcc``.
        Writes logs and TensorBoard events under ``args.output_dir``.
        Checkpoints are written from ``train_one_epoch``.
    """
    global best_srcc, best_plcc, best_klcc
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # Only the main process logs INFO to the console. Outside distributed mode
    # ``args.rank`` is just the CLI placeholder (possibly -1; TODO confirm the
    # default), so a single-process run always counts as the main process.
    is_main_process = not args.distributed or args.rank == 0
    logger = build_logger(
        output_dir=args.output_dir,
        log_file='{}_train.log'.format(args.run_name),
        rank=args.rank if args.distributed else None,
        level=logging.INFO,
        console_level=logging.INFO if is_main_process else logging.WARNING,
        file_level=logging.INFO
    )
    if args.gpu is not None:
        logger.info("Use GPU: {} for training".format(args.gpu))

    # ---- create model ----
    if args.arch.startswith('RA_'):
        # NOTE(review): the RA_ variant ignores args.pretrained and always
        # builds this fixed ViT-S/16 backbone with the listed config/checkpoint.
        model = RegionVisionTransformer(
            base_model_name='vit_small_patch16_224',
            pretrained=True,
            mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
            checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth',
            auto_download=True,
            force_download=False
        )
    else:
        try:
            logger.info(f"Loading model from torchvision {args.arch}")
            model = get_torch_model(model_name=args.arch, pretrained=args.pretrained, num_classes=1)
        except Exception as err:
            # Presumably raised for names torchvision does not provide; fall
            # back to timm. (Was a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            logger.info(f"Loading model from timm {args.arch} (torchvision failed: {err})")
            model = get_timm_model(model_name=args.arch, pretrained=args.pretrained, num_classes=1)

    # ---- place / wrap model ----
    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
        logger.info('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope. Otherwise,
        # DistributedDataParallel will use all available devices.
        if torch.cuda.is_available():
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(args.gpu)
                # With a single GPU per process and per DistributedDataParallel,
                # split the batch size and workers across the node's GPUs ourselves.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set.
                model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif torch.backends.mps.is_available():
        model = model.to(torch.device("mps"))
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # Device the training batches are moved to. ``args.gpu`` may be 0, so
    # compare against None instead of testing truthiness.
    if torch.cuda.is_available():
        if args.gpu is not None:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # ---- data loading ----
    train_dataset, val_dataset = build_dataloader.build_dataset(args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # ---- loss function, optimizer, and learning-rate scheduler ----
    criterion = build_loss_function(loss_name=args.loss_name)
    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, len(train_loader))

    # ---- optionally resume from a checkpoint ----
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            elif torch.cuda.is_available():
                # Map the checkpoint onto the specified single GPU.
                checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu))
            else:
                # A GPU index was given but CUDA is unavailable: load onto the
                # CPU (previously ``checkpoint`` was left unbound here).
                # load_state_dict copies tensors onto the parameters' devices.
                checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            # train_one_epoch stores best_srcc as a Python float. float() also
            # accepts a 0-d tensor on any device. Calling ``.to()`` on the float
            # used to raise AttributeError.
            best_srcc = float(checkpoint['best_srcc'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # ---- evaluation only ----
    if args.eval_only:
        # Keyword arguments: validate's signature is (val_loader, model, args,
        # criterion, logger). The old positional call swapped args/criterion
        # and omitted logger.
        results = validate(val_loader=val_loader, model=model, criterion=criterion, args=args, logger=logger)
        metrics = results['metrics']
        logger.info(f"Evaluation results: SRCC: {metrics['srcc']:.4f}, PLCC: {metrics['plcc']:.4f}, "
                    f"KLCC: {metrics['klcc']:.4f}, Loss: {metrics['loss']:.4f}")
        return

    # ---- training ----
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'tensorboard_logs', args.run_name))
    # NOTE(review): the upper bound is inclusive (args.epochs + 1), so this runs
    # args.epochs - args.start_epoch + 1 epochs; confirm that matches the
    # schedule length configured in build_scheduler.
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        best_srcc, best_plcc, best_klcc = train_one_epoch(
            train_loader, model, criterion, optimizer, scheduler, epoch, device, args,
            val_loader, writer, logger)
    writer.close()

    logger.info('################# Training Finished ##################')
    logger.info(f"Best SRCC: {best_srcc}, Best PLCC: {best_plcc}, Best KLCC: {best_klcc}")
    logger.info('######################################################')
def train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args, val_loader, writer, logger):
    """Train ``model`` for one epoch, running validation periodically.

    Validation runs every ``max(1, len(train_loader) // args.validate_num)``
    steps. After each validation:
      * the module-level best SRCC/PLCC/KLCC are updated;
      * the checkpoint-writing process saves a checkpoint;
      * when SRCC improves, that process also writes the per-image predictions
        to ``<output_dir>/<run_name>_best_val_results.csv``.

    Args:
        train_loader: DataLoader yielding dicts with ``'image_cropped'``,
            ``'label'`` and optionally ``'image_resized'``.
        model: Network called as ``model(image_cropped[, image_resized])``.
        criterion: Loss taking ``(predictions, targets)``, both flattened to 1-D.
        optimizer: Torch optimizer.
        scheduler: Scheduler exposing ``step_update(step)`` and ``state_dict()``.
        epoch: Current epoch index (used for step counters and the checkpoint).
        device: Device that training batches are moved to.
        args: Parsed CLI namespace.
        val_loader: Validation DataLoader passed on to ``validate``.
        writer: TensorBoard ``SummaryWriter``.
        logger: Logger from ``build_logger``.

    Returns:
        ``(best_srcc, best_plcc, best_klcc)`` after this epoch.
    """
    global best_srcc, best_plcc, best_klcc
    model.train()
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    srcc = AverageMeter('SRCC', ':6.4f')
    plcc = AverageMeter('PLCC', ':6.4f')
    klcc = AverageMeter('KLCC', ':6.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, srcc, plcc, klcc],
        prefix="Epoch: [{}]".format(epoch))

    num_batches = len(train_loader)
    # Clamp to 1: with validate_num > len(train_loader) the floor division is 0,
    # and ``(i + 1) % 0`` would raise ZeroDivisionError.
    validate_freq = max(1, num_batches // args.validate_num)
    # One process writes TensorBoard scalars. ``args.rank`` is only meaningful
    # in distributed runs; a single-process run is always the main process.
    is_main_process = not getattr(args, 'distributed', False) or args.rank == 0
    # Checkpoints and the best-results CSV are written by local rank 0 only.
    is_checkpoint_writer = (not args.multiprocessing_distributed
                            or args.rank % args.ngpus_per_node == 0)
    non_blocking = torch.device(device).type == 'cuda'
    global_step = epoch * num_batches

    end = time.time()
    for i, batch in enumerate(train_loader):
        step = global_step + i
        data_time.update(time.time() - end)

        # Use ``device`` rather than ``.cuda()`` so CPU/MPS runs work as well.
        image_cropped = batch['image_cropped'].to(device, non_blocking=non_blocking)
        target = batch['label'].to(device, non_blocking=non_blocking).view(-1)
        if 'image_resized' in batch:
            image_resized = batch['image_resized'].to(device, non_blocking=non_blocking)
            output = model(image_cropped, image_resized).view(-1)
        else:
            output = model(image_cropped).view(-1)

        batch_len = target.size(0)
        train_loss = criterion(output, target)

        # Per-batch correlations for monitoring only. Detach so the metric
        # computation stays out of the autograd graph.
        pred = output.detach()
        srcc.update(torchmetrics.functional.spearman_corrcoef(pred, target).item(), batch_len)
        plcc.update(torchmetrics.functional.pearson_corrcoef(pred, target).item(), batch_len)
        klcc.update(torchmetrics.functional.kendall_rank_corrcoef(pred, target).item(), batch_len)
        loss_value = train_loss.item()
        losses.update(loss_value, batch_len)
        if is_main_process:
            writer.add_scalars('Loss', {'train': loss_value}, step)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        scheduler.step_update(step)
        if is_main_process:
            writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], step)

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i + 1)

        # ---- periodic validation inside the epoch ----
        if (i + 1) % validate_freq == 0:
            # validate() switches the model to eval mode itself; every rank must
            # call it because it uses collective ops in distributed runs.
            results = validate(val_loader=val_loader, model=model, criterion=criterion, args=args, logger=logger)
            val_srcc = results['metrics']['srcc']
            val_plcc = results['metrics']['plcc']
            val_klcc = results['metrics']['klcc']
            val_loss = results['metrics']['loss']
            logger.info(f'Validation results: SRCC: {val_srcc:.4f}, PLCC: {val_plcc:.4f}, KLCC: {val_klcc:.4f}, Loss: {val_loss:.4f}')
            if is_main_process:
                # The validation loss goes on the same chart as the training loss.
                writer.add_scalars('Loss', {'val': val_loss}, step)
                writer.add_scalars('Metrics', {
                    'SRCC': val_srcc,
                    'PLCC': val_plcc,
                    'KLCC': val_klcc
                }, step)

            # Each best is tracked independently; "best" checkpoints follow SRCC.
            is_best = val_srcc > best_srcc
            best_srcc = max(val_srcc, best_srcc)
            best_plcc = max(val_plcc, best_plcc)
            best_klcc = max(val_klcc, best_klcc)

            if is_checkpoint_writer:
                save_checkpoint(args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_srcc': best_srcc,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, is_best)
                if is_best:
                    logger.info(
                        f'**BEST** Validation results: SRCC: {best_srcc:.4f}, PLCC: {best_plcc:.4f}, KLCC: {best_klcc:.4f}')
                    df = pd.DataFrame({
                        'image_name': results['image_names'],
                        'prediction': results['predictions'],
                        'ground_truth': results['ground_truth']
                    })
                    csv_filename = os.path.join(args.output_dir,
                                                f'{args.run_name}_best_val_results.csv')
                    df.to_csv(csv_filename, index=False)
            model.train()

    logger.info(
        f'**BEST** Validation results: SRCC: {best_srcc:.4f}, PLCC: {best_plcc:.4f}, KLCC: {best_klcc:.4f}')
    return best_srcc, best_plcc, best_klcc
@torch.no_grad()
def validate(val_loader, model, args, criterion, logger):
    """Run ``model`` over ``val_loader`` and compute SRCC / PLCC / KLCC.

    When ``torch.distributed`` is initialised, every process must call this.
    Predictions, targets and image names are all-gathered, so all ranks
    return the same results.

    Args:
        val_loader: DataLoader yielding dicts with ``'image_cropped'``,
            ``'label'``, ``'image_name'`` and optionally ``'image_resized'``.
            In distributed runs it is expected to use a non-shuffling
            ``DistributedSampler`` (as built in ``main_worker``).
        model: Network called as ``model(images[, image_resized])``.
        args: Parsed CLI namespace (``gpu``, ``print_freq``, optional
            ``patch_num``).
        criterion: Loss taking ``(predictions, targets)``.
        logger: Logger for progress messages.

    Returns:
        A dict with these keys:
          * ``'image_names'``, ``'predictions'``, ``'ground_truth'``:
            index-aligned lists.
          * ``'metrics'``: a dict with ``'srcc'``, ``'plcc'``, ``'klcc'``
            and ``'loss'`` (average reduced loss).
        Correlations fall back to 0.0 when they cannot be computed,
        e.g. NaN/inf values.
    """
    model.eval()
    val_dataset_len = len(val_loader.dataset)
    val_loader_len = len(val_loader)
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    # Resolve the device once (it is loop-invariant). With CUDA available but
    # no GPU index (DataParallel / device-less DDP), use the current CUDA
    # device. Falling back to CPU would put targets on a different device than
    # the model output and break the loss.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}') if args.gpu is not None else torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    non_blocking = device.type == 'cuda'

    pred_chunks = []
    gt_chunks = []
    local_names = []
    tic = time.time()
    for i, batch in enumerate(val_loader):
        images = batch['image_cropped'].to(device, non_blocking=non_blocking)
        target = batch['label'].to(device, non_blocking=non_blocking).view(-1)
        if 'image_resized' in batch:
            image_resized = batch['image_resized'].to(device, non_blocking=non_blocking)
            output = model(images, image_resized).view(-1)
        else:
            output = model(images).view(-1)

        loss = reduce_tensor(criterion(output, target))
        losses.update(loss.item(), target.size(0))
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Keep predictions, ground truth and image names for the final metrics.
        pred_chunks.append(output)
        gt_chunks.append(target)
        local_names.extend(batch['image_name'])

        if i % args.print_freq == 0:
            logger.info(
                f"Test: [{i}/{val_loader_len}]\t"
                f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                f"Loss {losses.val:.4f} ({losses.avg:.4f})\t"
            )

    pred_scores = torch.cat(pred_chunks)
    gt_scores = torch.cat(gt_chunks)
    all_img_names = list(local_names)

    if dist.is_initialized():
        world_size = dist.get_world_size()
        # DistributedSampler(shuffle=False) pads the index list to a multiple of
        # the world size and gives rank r the indices r, r+W, r+2W, ...
        # Interleaving the per-rank results (not concatenating them) restores
        # dataset order. Slicing [:len(dataset)] then drops exactly the padding
        # duplicates. Concatenation would drop real samples and keep duplicates.
        names_per_rank = [None] * world_size
        dist.all_gather_object(names_per_rank, local_names)
        all_img_names = [name for names_at_pos in zip(*names_per_rank)
                         for name in names_at_pos][:val_dataset_len]

        preds_per_rank = [torch.zeros_like(pred_scores) for _ in range(world_size)]
        dist.all_gather(preds_per_rank, pred_scores)
        pred_scores = torch.stack(preds_per_rank, dim=1).reshape(-1)[:val_dataset_len]

        gts_per_rank = [torch.zeros_like(gt_scores) for _ in range(world_size)]
        dist.all_gather(gts_per_rank, gt_scores)
        gt_scores = torch.stack(gts_per_rank, dim=1).reshape(-1)[:val_dataset_len]

    final_preds = pred_scores.float().detach()
    final_grotruth = gt_scores.float().detach()

    # Average multiple patches per image. This is now applied in single-process
    # runs too, and image names are reduced to match so the lists stay aligned.
    # Assumes each image contributes ``patch_num`` consecutive samples
    # (TODO confirm against the dataset).
    patch_num = getattr(args, 'patch_num', 1) or 1
    if patch_num > 1:
        final_preds = final_preds.reshape(-1, patch_num).mean(dim=-1)
        final_grotruth = final_grotruth.reshape(-1, patch_num).mean(dim=-1)
        all_img_names = all_img_names[::patch_num]

    # Correlation coefficients over the whole validation set.
    try:
        logger.info(f"len of dataset: {val_dataset_len}, final_preds shape: {final_preds.shape}, final_grotruth shape: {final_grotruth.shape}")
        if torch.isnan(final_preds).any() or torch.isinf(final_preds).any() or \
                torch.isnan(final_grotruth).any() or torch.isinf(final_grotruth).any():
            raise ValueError("Found NaN or inf values in predictions or ground truth")
        test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_grotruth).item()
        test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_grotruth).item()
        test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_grotruth).item()
    except Exception as e:
        # Deliberate best effort: report zero correlation instead of aborting training.
        logger.warning(f"Error in calculating correlations: {str(e)}. Resetting cc relation to zero...")
        test_plcc = 0.0
        test_srcc = 0.0
        test_klcc = 0.0

    # Image names, predictions and ground truth are index-aligned.
    results = {
        'image_names': all_img_names,
        'predictions': final_preds.cpu().numpy().tolist(),
        'ground_truth': final_grotruth.cpu().numpy().tolist(),
        'metrics': {
            'srcc': test_srcc,
            'plcc': test_plcc,
            'klcc': test_klcc,
            'loss': losses.avg
        }
    }
    return results
class Summary(Enum):
    """Which statistic ``AverageMeter.summary`` reports."""

    NONE = 0     # report nothing (empty string)
    AVERAGE = 1  # report the running average
    SUM = 2      # report the running sum
    COUNT = 3    # report the number of samples seen
class AverageMeter(object):
    """Track the latest value plus a count-weighted running sum and average.

    Args:
        name: Label used when the meter is printed.
        fmt: Format spec (e.g. ``':6.3f'``) applied to ``val`` and ``avg``
            in ``__str__``.
        summary_type: Which statistic ``summary()`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: an ``n=0`` update on an empty meter used to divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self):
        """Sum ``sum`` and ``count`` over all ranks and recompute ``avg``.

        Requires an initialised ``torch.distributed`` process group. The
        reduction tensor is placed on CUDA, MPS or CPU, whichever is available.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # Every rank may have seen zero samples; avoid dividing by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return a one-line summary according to ``summary_type``.

        Raises:
            ValueError: if ``summary_type`` is not a ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Print progress lines of the form ``prefix[batch/total]`` plus each meter.

    Args:
        num_batches: Total number of batches; fixes the width of the counter.
        meters: Objects whose ``str()`` (for ``display``) and ``summary()``
            (for ``display_summary``) are printed.
        prefix: Text placed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the counter for ``batch`` followed by every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print `` *`` followed by each meter's summary, space-separated."""
        summaries = [meter.summary() for meter in self.meters]
        print(' '.join([" *", *summaries]))

    def _get_batch_fmtstr(self, num_batches):
        """Build e.g. ``'[{:3d}/250]'`` for ``num_batches=250``."""
        width = len(str(num_batches // 1))
        field = f'{{:{width}d}}'
        return '[' + field + '/' + field.format(num_batches) + ']'
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor; the top-k is taken over dim 1, one row per sample.
        target: Ground-truth class indices, one per row of ``output``.
        topk: The k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, in the order of ``topk``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, top_idx = output.topk(largest_k, 1, True, True)
    top_idx = top_idx.t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
        for k in topk
    ]
if __name__ == '__main__':
    args = get_args().parse_args()
    args.run_name = args.arch + '_' + args.dataset + '_' + args.metric_type

    # Output directory plus a per-run TensorBoard directory inside it.
    tensorboard_dir = os.path.join(args.output_dir, 'tensorboard_logs', args.run_name)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    # Persist the parsed arguments. NOTE: this writes the dict's str(), not
    # real YAML, despite the file extension.
    config_path = os.path.join(args.output_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        f.write(str(vars(args)))

    main(args)