miqa
xiaoqi-wang committed on
Commit
4e62d30
·
verified ·
1 Parent(s): e29b006

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +613 -0
train.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+ from enum import Enum
8
+ import pandas as pd
9
+ import torch
10
+ import logging
11
+
12
+ import torch.nn.parallel
13
+ import torch.optim
14
+ import torch.utils.data
15
+ import torch.utils.data.distributed
16
+ import torch.distributed as dist
17
+ import torch.multiprocessing as mp
18
+ import torchmetrics
19
+ from torch.utils.tensorboard import SummaryWriter
20
+
21
+ from configs.args_base import get_args
22
+
23
+ from utils.losses import build_loss_function
24
+ from utils.optimizer import build_optimizer
25
+ from utils.lr_scheduler import build_scheduler
26
+ from utils.logger import build_logger
27
+ from utils.misc import setup_seed, reduce_tensor, save_checkpoint
28
+ from data import build_dataloader
29
+ from models.MIQA_base import get_torch_model, get_timm_model
30
+ from models.RA_MIQA import RegionVisionTransformer
31
+
32
# Best validation correlations seen so far. main_worker and train_one_epoch
# rebind these names through `global` statements.
best_srcc = 0.
best_plcc = 0.
best_klcc = 0.
33
+
34
def main(args):
    """Resolve the launch configuration and start the worker(s).

    Seeds the run when ``args.seed`` is given, fills in ``args.world_size``
    from the ``WORLD_SIZE`` environment variable for ``env://`` launches,
    sets ``args.distributed``, and then either spawns one ``main_worker``
    process per GPU or calls ``main_worker`` directly in this process.

    Args:
        args: Parsed command-line namespace; mutated in place
            (``world_size`` and ``distributed`` are written here).
    """
    if args.seed is not None:
        setup_seed(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # A world size of -1 means "not given": read it from the launcher's env.
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    has_cuda = torch.cuda.is_available()
    ngpus_per_node = torch.cuda.device_count() if has_cuda else 1
    if has_cuda and ngpus_per_node == 1 and args.dist_backend == "nccl":
        warnings.warn(
            "nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")

    if not args.multiprocessing_distributed:
        # Single process: run the worker inline.
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # One process per GPU on this node, so the total world size is scaled by
    # the per-node GPU count before the workers are spawned.
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
72
+
73
+
74
def main_worker(gpu, ngpus_per_node, args):
    """Build everything for one process and run training or evaluation.

    Steps, in order: initialise the process group (distributed runs only),
    build the logger and the model, place the model on the selected device,
    build datasets/loaders, loss, optimizer and scheduler, optionally resume
    from ``args.resume``, then either run a single validation pass
    (``args.eval_only``) or train from ``args.start_epoch`` through
    ``args.epochs`` inclusive.

    Args:
        gpu: GPU index for this process, or ``None``.
        ngpus_per_node: Number of GPUs on this node as computed by ``main``.
        args: Parsed command-line namespace; mutated in place (``gpu``,
            ``ngpus_per_node``, ``rank``, ``batch_size``, ``workers`` and
            ``start_epoch`` may be rewritten).
    """
    global best_srcc, best_plcc, best_klcc
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    logger = build_logger(
        output_dir=args.output_dir,
        log_file='{}_train.log'.format(args.run_name),
        rank=args.rank if args.distributed else None,
        level=logging.INFO,
        console_level=logging.INFO if args.rank in [0, None] else logging.WARNING,
        file_level=logging.INFO
    )

    if args.gpu is not None:
        logger.info("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.arch.startswith('RA_'):
        model = RegionVisionTransformer(
            base_model_name='vit_small_patch16_224',
            pretrained=True,
            mmseg_config_path='models/model_configs/fcn_sere-small_finetuned_fp16_8x32_224x224_3600_imagenets919.py',
            checkpoint_path='models/checkpoints/sere_finetuned_vit_small_ep100.pth',
            auto_download=True,
            force_download=False
        )
    else:
        try:
            logger.info(f"Loading model form torchvision {args.arch}")
            model = get_torch_model(model_name=args.arch, pretrained=args.pretrained, num_classes=1)
        except Exception:
            # Fall back to timm when the torchvision loader fails. A bare
            # `except:` would also swallow KeyboardInterrupt / SystemExit.
            logger.info(f"Loading model form timm {args.arch}")
            model = get_timm_model(model_name=args.arch, pretrained=args.pretrained, num_classes=1)

    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
        logger.info('using CPU, this will be slow')

    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if torch.cuda.is_available():
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(args.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)

    elif args.gpu is not None and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        model = model.to(device)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # Device handed to the training loop. `is not None` (rather than
    # truthiness) so that GPU index 0 is treated as an explicit choice.
    if torch.cuda.is_available():
        if args.gpu is not None:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Data loading
    train_dataset, val_dataset = build_dataloader.build_dataset(args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = build_loss_function(loss_name=args.loss_name)
    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, len(train_loader))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Always pick a map_location so `checkpoint` is bound on every
            # path, including "GPU id given but CUDA unavailable".
            if not torch.cuda.is_available():
                map_location = 'cpu'
            elif args.gpu is not None:
                # Map model to be loaded to specified single gpu.
                map_location = 'cuda:{}'.format(args.gpu)
            else:
                map_location = None
            checkpoint = torch.load(args.resume, map_location=map_location)
            args.start_epoch = checkpoint['epoch']
            # train_one_epoch stores best_srcc as a plain float, which has no
            # `.to()`; float() also accepts a 0-dim tensor.
            best_srcc = float(checkpoint['best_srcc'])
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # evaluate on validation set
    if args.eval_only:
        # Keyword arguments: validate's positional order is
        # (val_loader, model, args, criterion, logger).
        validate(val_loader=val_loader, model=model, criterion=criterion, args=args, logger=logger)
        return

    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'tensorboard_logs', args.run_name))

    for epoch in range(args.start_epoch, args.epochs+1):
        if args.distributed:
            # Reseed the sampler so each epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)

        # train for one epoch
        best_srcc, best_plcc, best_klcc = train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args, val_loader, writer, logger)

    writer.close()

    logger.info('################# Training Finished ##################')
    logger.info(f"Best SRCC: {best_srcc}, Best PLCC: {best_plcc}, Best KLCC: {best_klcc}")
    logger.info('######################################################')
230
+
231
+
232
def train_one_epoch(train_loader, model, criterion, optimizer, scheduler, epoch, device, args, val_loader, writer, logger):
    """Train for one epoch, validating several times along the way.

    Every ``len(train_loader) // args.validate_num`` iterations the model is
    evaluated with ``validate``. The module-level best SRCC/PLCC/KLCC values
    are updated, a checkpoint is saved, and on a new best SRCC the
    per-image validation results are written to a CSV file.

    Args:
        train_loader: Iterable of batch dicts with ``'image_cropped'`` and
            ``'label'`` entries and an optional ``'image_resized'`` entry.
        model: Network to train; called with one or two image tensors.
        criterion: Loss callable taking ``(output, target)``.
        optimizer: Optimizer stepping the model parameters.
        scheduler: Scheduler exposing ``step_update(step)`` and
            ``state_dict()``.
        epoch: Current epoch index (used for the global step and logging).
        device: Device the batch tensors are moved to.
        args: Parsed command-line namespace.
        val_loader: Validation loader forwarded to ``validate``.
        writer: TensorBoard ``SummaryWriter``.
        logger: Logger for progress messages.

    Returns:
        Tuple ``(best_srcc, best_plcc, best_klcc)`` after this epoch.
    """
    global best_srcc, best_plcc, best_klcc
    model.train()

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    srcc = AverageMeter('SRCC', ':6.4f')
    plcc = AverageMeter('PLCC', ':6.4f')
    klcc = AverageMeter('KLCC', ':6.4f')

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, srcc, plcc, klcc],
        prefix="Epoch: [{}]".format(epoch))

    # Clamp to >= 1: with fewer batches than args.validate_num the integer
    # division yields 0 and the modulo below would raise ZeroDivisionError.
    validate_freq = max(1, len(train_loader) // args.validate_num)

    # A single-process run is always the "main" process. main_worker compares
    # args.rank against -1 as an "unset" value, so a bare `rank == 0` test
    # would skip TensorBoard logging there — TODO confirm the CLI default.
    is_main_process = (not args.distributed) or args.rank == 0

    end = time.time()
    global_step = epoch * len(train_loader)

    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Use the device supplied by the caller instead of `.cuda(...)`, so
        # the CPU and MPS paths set up in main_worker also work.
        image_cropped = batch['image_cropped'].to(device, non_blocking=True)
        target = batch['label'].to(device, non_blocking=True).view(-1)

        if 'image_resized' in batch:
            image_resized = batch['image_resized'].to(device, non_blocking=True)
            output = model(image_cropped, image_resized).view(-1)
        else:
            output = model(image_cropped).view(-1)

        target_len = target.size(0)
        train_loss = criterion(output, target)

        # Calculate metrics during training sessions (on detached
        # predictions: they are for monitoring only and need no gradient).
        preds = output.detach()
        srcc_train = torchmetrics.functional.spearman_corrcoef(preds, target).item()
        plcc_train = torchmetrics.functional.pearson_corrcoef(preds, target).item()
        klcc_train = torchmetrics.functional.kendall_rank_corrcoef(preds, target).item()

        # Update Metrics
        losses.update(train_loss.item(), target_len)
        srcc.update(srcc_train, target_len)
        plcc.update(plcc_train, target_len)
        klcc.update(klcc_train, target_len)

        # Add training loss to the writer
        writer.add_scalars('Loss', {
            'train': train_loss.item()
        }, global_step + i)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # The scheduler is advanced per iteration, not per epoch.
        scheduler.step_update(global_step + i)

        # Record the current learning rate
        if is_main_process:
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('Learning_Rate', current_lr, global_step + i)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i + 1)

        # Perform multiple verifications within an epoch
        if (i + 1) % validate_freq == 0:
            model.eval()

            results = validate(val_loader=val_loader, model=model, criterion=criterion, args=args, logger=logger)
            val_srcc = results['metrics']['srcc']
            val_plcc = results['metrics']['plcc']
            val_klcc = results['metrics']['klcc']
            val_loss = results['metrics']['loss']
            logger.info(f'Validation results: SRCC: {val_srcc:.4f}, PLCC: {val_plcc:.4f}, KLCC: {val_klcc:.4f}, Loss: {val_loss:.4f}')
            if is_main_process:
                # Add the validation loss to the same loss chart
                writer.add_scalars('Loss', {
                    'val': val_loss
                }, global_step + i)

                # Add all performance metrics to the same Metrics chart.
                writer.add_scalars('Metrics', {
                    'SRCC': val_srcc,
                    'PLCC': val_plcc,
                    'KLCC': val_klcc
                }, global_step + i)

            # Model selection is driven by SRCC only; the best PLCC/KLCC are
            # tracked independently and may come from a different step.
            is_best = val_srcc > best_srcc
            best_srcc = max(val_srcc, best_srcc)
            best_plcc = max(val_plcc, best_plcc)
            best_klcc = max(val_klcc, best_klcc)

            # Save the best model and results
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                        and args.rank % args.ngpus_per_node == 0):
                save_checkpoint(args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_srcc': best_srcc,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, is_best)

                if is_best:
                    logger.info(
                        f'**BEST** Validation results: SRCC: {best_srcc:.4f}, PLCC: {best_plcc:.4f}, KLCC: {best_klcc:.4f}')

                    df = pd.DataFrame({
                        'image_name': results['image_names'],
                        'prediction': results['predictions'],
                        'ground_truth': results['ground_truth']
                    })
                    csv_filename = os.path.join(args.output_dir,
                                                f'{args.run_name}_best_val_results.csv')
                    df.to_csv(csv_filename, index=False)

            # validate() leaves the model in eval mode; switch back.
            model.train()

    logger.info(
        f'**BEST** Validation results: SRCC: {best_srcc:.4f}, PLCC: {best_plcc:.4f}, KLCC: {best_klcc:.4f}')
    return best_srcc, best_plcc, best_klcc
364
+
365
@torch.no_grad()
def validate(val_loader, model, args, criterion, logger=None):
    """Run the model over the validation set and compute correlations.

    Args:
        val_loader: Iterable of batch dicts with ``'image_cropped'``,
            ``'label'`` and ``'image_name'`` entries and an optional
            ``'image_resized'`` entry.
        model: Network to evaluate; switched to eval mode here.
        args: Namespace providing ``gpu``, ``print_freq`` and ``patch_num``.
        criterion: Loss callable taking ``(output, target)``.
        logger: Logger for progress messages; defaults to this module's
            logger when omitted.

    Returns:
        dict with keys ``'image_names'``, ``'predictions'``,
        ``'ground_truth'`` (lists) and ``'metrics'`` (dict holding
        ``'srcc'``, ``'plcc'``, ``'klcc'`` and the average ``'loss'``).
        The three correlations are reported as 0.0 when they cannot be
        computed.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    model.eval()
    val_dataset_len = len(val_loader.dataset)
    val_loader_len = len(val_loader)
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    # Resolve the device once. With CUDA available but no explicit GPU index
    # (the DataParallel / default-device path) tensors go to the current CUDA
    # device; sending them to the CPU would put `target` on a different
    # device from the model output when the loss is computed.
    if torch.cuda.is_available():
        device = torch.device('cuda' if args.gpu is None else f'cuda:{args.gpu}')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    non_blocking = device.type == 'cuda'

    temp_pred_scores = []
    temp_gt_scores = []
    temp_img_names = []
    time_tmp = time.time()

    for i, batch in enumerate(val_loader):
        images = batch['image_cropped'].to(device, non_blocking=non_blocking)
        target = batch['label'].to(device, non_blocking=non_blocking)

        if 'image_resized' in batch:
            image_resized = batch['image_resized'].to(device, non_blocking=non_blocking)
            output = model(images, image_resized).view(-1)
        else:
            output = model(images).view(-1)

        loss = criterion(output, target.view(-1))
        loss = reduce_tensor(loss)
        losses.update(loss.item(), target.size(0))

        batch_time.update(time.time() - time_tmp)
        time_tmp = time.time()

        # Save predicted values, gt values, and image names
        temp_pred_scores.append(output.view(-1))
        temp_gt_scores.append(target.view(-1))
        temp_img_names.extend(batch['image_name'])

        if i % args.print_freq == 0:
            logger.info(
                f"Test: [{i}/{val_loader_len}]\t"
                f"Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                f"Loss {losses.val:.4f} ({losses.avg:.4f})\t"
            )

    # Combine the results of all batches
    pred_scores = torch.cat(temp_pred_scores)
    gt_scores = torch.cat(temp_gt_scores)

    # Distributed processing
    if torch.distributed.is_initialized():
        # Collect the results of all processes.
        # NOTE(review): the gathered lists are concatenated rank by rank and
        # then cut to the dataset length; this assumes any sampler padding
        # ends up at the tail of the concatenation — confirm against the
        # sampler's index layout.
        img_names_gather = [None for _ in range(dist.get_world_size())]
        torch.distributed.all_gather_object(img_names_gather, temp_img_names)
        all_img_names = []
        for names in img_names_gather:
            all_img_names.extend(names)
        all_img_names = all_img_names[:val_dataset_len]  # truncate to the actual dataset size

        preds_gather_list = [
            torch.zeros_like(pred_scores) for _ in range(dist.get_world_size())
        ]
        torch.distributed.all_gather(preds_gather_list, pred_scores)
        gather_preds = torch.cat(preds_gather_list, dim=0)[:val_dataset_len]

        grotruth_gather_list = [
            torch.zeros_like(gt_scores) for _ in range(dist.get_world_size())
        ]
        torch.distributed.all_gather(grotruth_gather_list, gt_scores)
        gather_grotruth = torch.cat(grotruth_gather_list, dim=0)[:val_dataset_len]

        if args.patch_num > 1:
            # Average each run of `patch_num` consecutive entries into one
            # score (presumably the patches of one image — TODO confirm).
            gather_preds_matrix = gather_preds.view(-1, args.patch_num)

            gather_preds = (gather_preds_matrix.mean(dim=-1)).squeeze()
            gather_grotruth = (
                (gather_grotruth.view(-1, args.patch_num)).mean(dim=-1)
            ).squeeze()

        final_preds = gather_preds.float().detach()
        final_grotruth = gather_grotruth.float().detach()
    else:
        final_preds = pred_scores.float().detach()
        final_grotruth = gt_scores.float().detach()
        all_img_names = temp_img_names

    # Calculate the correlation coefficient
    try:
        logger.info(f"len of dataset: {val_dataset_len}, final_preds shape: {final_preds.shape}, final_grotruth shape: {final_grotruth.shape}")
        # Check for the presence of NaN or inf
        if torch.isnan(final_preds).any() or torch.isinf(final_preds).any() or \
                torch.isnan(final_grotruth).any() or torch.isinf(final_grotruth).any():
            raise ValueError("Found NaN or inf values in predictions or ground truth")

        test_srcc = torchmetrics.functional.spearman_corrcoef(final_preds, final_grotruth).item()
        test_plcc = torchmetrics.functional.pearson_corrcoef(final_preds, final_grotruth).item()
        test_klcc = torchmetrics.functional.kendall_rank_corrcoef(final_preds, final_grotruth).item()

    except Exception as e:
        # Deliberate best-effort: a failed metric computation is logged and
        # reported as zero correlation instead of aborting the run.
        logger.warning(f"Error in calculating correlations: {str(e)}. Resetting cc relation to zero...")
        test_plcc = 0.0
        test_srcc = 0.0
        test_klcc = 0.0

    # Create a result dictionary containing the correspondence between image names, predicted values, and actual values.
    results = {
        'image_names': all_img_names,
        'predictions': final_preds.cpu().numpy().tolist(),
        'ground_truth': final_grotruth.cpu().numpy().tolist(),
        'metrics': {
            'srcc': test_srcc,
            'plcc': test_plcc,
            'klcc': test_klcc,
            'loss': losses.avg
        }
    }

    return results
497
+
498
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary`` reports."""

    NONE = 0     # summary() returns an empty string
    AVERAGE = 1  # summary() reports the running average
    SUM = 2      # summary() reports the accumulated sum
    COUNT = 3    # summary() reports the number of samples
503
+
504
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum), ``count`` (total weight) and ``avg`` (``sum / count``,
    or 0 while ``count`` is 0).
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        """Create a meter.

        Args:
            name: Label used in ``__str__`` and ``summary``.
            fmt: Format spec (including the leading colon) applied to
                ``val`` and ``avg`` in ``__str__``.
            summary_type: ``Summary`` member selecting what ``summary``
                reports.
        """
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch) must not raise.
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across all processes and refresh ``avg``.

        Requires an initialised ``torch.distributed`` process group.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # tolist() yields floats; keep count an integer like in update().
        self.count = int(self.count)
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` formatted with ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the one-statistic summary chosen by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a known ``Summary`` member.
        """
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)
555
+
556
+
557
class ProgressMeter(object):
    """Prints a ``[current/total]`` counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        """Store the meters and precompute the batch-counter format.

        Args:
            num_batches: Total number of batches, shown after the slash.
            meters: Objects rendered with ``str()`` in ``display`` and with
                ``.summary()`` in ``display_summary``.
            prefix: Text placed before the batch counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print the summary of every meter on one line, space-separated."""
        print(' '.join([" *"] + [meter.summary() for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a ``'[{:Nd}/total]'`` template, N being the width of the total."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))
577
+
578
+
579
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: Score tensor; the top-k search runs along dimension 1.
        target: Tensor of target class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        List with one single-element tensor per k, holding the percentage of
        rows whose target is among the k highest-scoring entries.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the best `largest_k` scores per row, transposed so that
        # row j holds every sample's j-th best guess.
        top_idx = output.topk(largest_k, 1, True, True)[1].t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
594
+
595
+
596
if __name__ == '__main__':
    # Script entry point: parse the CLI, prepare output folders, dump the
    # effective configuration, then hand over to main().
    args = get_args().parse_args()

    # Run name used for the log file, TensorBoard folder and result CSV.
    args.run_name = '_'.join([args.arch, args.dataset, args.metric_type])

    tensorboard_dir = os.path.join(args.output_dir, 'tensorboard_logs', args.run_name)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    # save config file (the content is the str() of the argument dict)
    config_path = os.path.join(args.output_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        f.write(str(vars(args)))

    main(args)
610
+
611
+
612
+
613
+