| |
| """ |
| Train and eval functions used in main.py |
| """ |
|
|
| import math |
| from PIL import Image |
| import os |
| import sys |
| from typing import Iterable |
| import numpy as np |
| from util.utils import slprint, to_device |
| from sklearn.metrics import accuracy_score |
| import numpy as np |
| from itertools import zip_longest |
| import torch |
| |
| from compute_rouge import compute_rouge |
| import util.misc as utils |
| from dino_datasets.coco_eval import CocoEvaluator |
| from dino_datasets.panoptic_eval import PanopticEvaluator |
| from sklearn.metrics import accuracy_score, f1_score |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| |
| |
| |
| |
| from itertools import zip_longest |
| |
| |
| |
| |
|
|
|
|
|
|
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, data_loader_2: Iterable, data_loader_3: Iterable,
                    data_loader_4: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    wo_class_error=False, lr_scheduler=None, args=None, logger=None, ema_m=None):
    """Train ``model`` for one epoch on the loader(s) selected by ``args.traning_step``.

    ``args.traning_step`` selects the data source:
      * 1 -> ``data_loader_3``
      * 3 -> ``data_loader``, ``data_loader_2`` and ``data_loader_3`` in lockstep
        (one optimizer step per triple of batches; stops at the shortest loader)
      * 4 -> ``data_loader``
      * 5 -> ``data_loader_2``
      * 6 -> ``data_loader_4``
    Every loader must yield ``(samples, targets, prompt)`` batches.

    Args:
        max_norm: if > 0, gradients are clipped to this total norm before each step
            (both with and without AMP).
        lr_scheduler: stepped once per iteration when ``args.onecyclelr`` is set.
        ema_m: EMA wrapper updated once per iteration when ``args.use_ema`` is set and
            ``epoch >= args.ema_epoch``.

    Returns:
        dict mapping each logged meter name to its global average over the epoch
        (plus ``weight_<name>`` entries when the criterion has ``loss_weight_decay``).

    Raises:
        ValueError: if ``args.traning_step`` is not 1, 3, 4, 5 or 6.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    # When args.use_dn is set, the targets are also passed to the model's forward.
    need_tgt_for_training = getattr(args, 'use_dn', False)

    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    if not wo_class_error:
        metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 200

    def _logged(loader):
        """Wrap ``loader`` so the metric logger reports progress while it is consumed."""
        return metric_logger.log_every(loader, print_freq, header, logger=logger)

    # Each item of ``data_iter`` is a tuple of batches sharing one optimizer step.
    step = args.traning_step
    single_loader_by_step = {1: data_loader_3, 4: data_loader, 5: data_loader_2, 6: data_loader_4}
    if step == 3:
        data_iter = zip(_logged(data_loader), _logged(data_loader_2), _logged(data_loader_3))
    elif step in single_loader_by_step:
        data_iter = ((batch,) for batch in _logged(single_loader_by_step[step]))
    else:
        raise ValueError("Invalid training step")

    _cnt = 0
    for batches in data_iter:
        losses = torch.as_tensor(0.).to(device)  # weighted loss that is backpropagated
        loss_value = 0.0  # process-averaged loss, used for logging and the NaN check
        combined_loss_dict_scaled = {}
        combined_loss_dict_unscaled = {}
        loss_dict_reduced = {}

        for i, batch in enumerate(batches):
            if batch is None:
                continue
            samples, targets, prompt = batch
            samples = samples.to(device)
            targets = [
                {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                for t in targets
            ]

            with torch.cuda.amp.autocast(enabled=args.amp):
                if need_tgt_for_training:
                    outputs = model(samples, prompt, targets)
                else:
                    outputs = model(samples, prompt)
                loss_dict = criterion(outputs, targets, args)
            weight_dict = criterion.weight_dict

            # Backprop through this process's own losses. utils.reduce_dict is for logging:
            # presumably (as in DETR/DINO's util.misc) it all-reduces under torch.no_grad()
            # in distributed runs, so its outputs carry no gradient. With a single process
            # it returns loss_dict unchanged, so both paths see the same tensors.
            local_scaled = {k: v * weight_dict[k] for k, v in loss_dict.items() if k in weight_dict}
            loss = sum(local_scaled.values()) if local_scaled else sum(loss_dict.values())
            losses = losses + loss

            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_dict_reduced_scaled = {
                k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict
            }
            # Suffix keys with the 1-based batch position so losses from different loaders stay apart.
            for k, v in loss_dict_reduced.items():
                combined_loss_dict_unscaled[f'{k}_unscaled_{i + 1}'] = v
            for k, v in loss_dict_reduced_scaled.items():
                combined_loss_dict_scaled[f'{k}_scaled_{i + 1}'] = v
            if loss_dict_reduced_scaled:
                reduced_loss = sum(loss_dict_reduced_scaled.values())
            else:
                reduced_loss = sum(loss_dict_reduced.values())
            loss_value += float(reduced_loss)

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(combined_loss_dict_unscaled)
            sys.exit(1)

        optimizer.zero_grad()
        if args.amp:
            scaler.scale(losses).backward()
            if max_norm > 0:
                # Clip the real (unscaled) gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            if losses != 0:
                # NOTE(review): retain_graph=True is kept from the original. It is only needed
                # if part of the graph is reused by a later backward; remove it if not.
                losses.backward(retain_graph=True)
                # Previously only the AMP path honoured max_norm.
                if max_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        if args.onecyclelr:
            lr_scheduler.step()
        if args.use_ema and epoch >= args.ema_epoch:
            ema_m.update(model)

        metric_logger.update(loss=loss_value, **{
            **{k: v.item() for k, v in combined_loss_dict_unscaled.items()},
            **{k: v.item() for k, v in combined_loss_dict_scaled.items()},
        })
        if 'class_error' in loss_dict_reduced:
            metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        _cnt += 1
        if args.debug and _cnt % 15 == 0:
            print("BREAK!" * 5)
            break

    if getattr(criterion, 'loss_weight_decay', False):
        criterion.loss_weight_decay(epoch=epoch)
    if getattr(criterion, 'tuning_matching', False):
        criterion.tuning_matching(epoch)

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
    if getattr(criterion, 'loss_weight_decay', False):
        resstat.update({f'weight_{k}': v for k, v in criterion.weight_dict.items()})
    return resstat


def plot_pred_mask(pred_mask, threshold=0.5, title="Predicted Mask"):
    """Show a binarised predicted mask with matplotlib.

    Args:
        pred_mask: mask tensor of shape (H, W), (C, H, W) or (B, C, H, W). For 3-D input
            the first channel is shown; for 4-D input, the first channel of the first sample.
            Values outside [0, 1] are treated as logits and passed through a sigmoid.
        threshold: pixels strictly above this probability are drawn as foreground.
        title: figure title.
    """
    if pred_mask.dim() == 4:
        pred_mask = pred_mask[0, 0]
    elif pred_mask.dim() == 3:
        pred_mask = pred_mask[0]

    # Values outside [0, 1] cannot be probabilities, so apply the sigmoid. The old check
    # (max() > 1 only) thresholded mixed-sign logits with max <= 1 as raw values.
    if pred_mask.max() > 1 or pred_mask.min() < 0:
        pred_mask = pred_mask.sigmoid()
    binary_mask_np = (pred_mask > threshold).float().detach().cpu().numpy()

    plt.figure(figsize=(6, 6))
    plt.imshow(binary_mask_np, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()


def dice_score2(pred, target, epsilon=1e-6):
    """Return the Dice coefficient between binarised ``pred`` and ``target``.

    Both inputs are thresholded at 0.5 and must be broadcastable to a common shape.
    Lists are accepted: a list of tensors is stacked, a list of numbers becomes a tensor.
    For inputs with three or more dimensions the leading dimension is treated as the
    batch: one Dice score is computed per sample and their mean is returned. Lower-rank
    inputs are scored as a single mask. (The previous version divided the pooled score by
    ``len(pred)``, which for a 2-D mask meant dividing by its height.)

    Args:
        pred: predicted probabilities or mask.
        target: ground-truth mask.
        epsilon: smoothing term added to numerator and denominator; two empty masks score 1.0.

    Returns:
        Dice score as a Python float (NaN for an empty batch).
    """
    if isinstance(pred, list):
        pred = torch.stack(pred) if isinstance(pred[0], torch.Tensor) else torch.tensor(pred)
    if isinstance(target, list):
        target = torch.stack(target) if isinstance(target[0], torch.Tensor) else torch.tensor(target)

    pred = (pred.detach().float().cpu() > 0.5).float()
    target = (target.detach().float().cpu() > 0.5).float()
    pred, target = torch.broadcast_tensors(pred, target)

    if pred.dim() >= 3:
        # One row per sample.
        pred = pred.flatten(1)
        target = target.flatten(1)
    else:
        pred = pred.flatten().unsqueeze(0)
        target = target.flatten().unsqueeze(0)

    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    dice = (2.0 * intersection + epsilon) / (union + epsilon)
    return dice.mean().item()
|
|
|
|
def multiclass_dice_score(pred, target, num_classes, epsilon=1e-6):
    """Soft multi-class Dice score.

    Args:
        pred: (B, C, H, W) raw logits; a softmax over the class dimension is applied here.
        target: (B, H, W) integer class indices in ``0..num_classes-1``.
        num_classes: number of classes C.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        ``(mean_dice, dice_scores)``: ``dice_scores`` is a list with one 0-dim tensor per
        class (the batch-averaged Dice for that class), and ``mean_dice`` is their mean.
    """
    probs = F.softmax(pred, dim=1)
    onehot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

    def _class_dice(cls_idx):
        # Per-sample Dice for one class, averaged over the batch.
        p = probs[:, cls_idx, :, :]
        t = onehot[:, cls_idx, :, :]
        overlap = (p * t).sum(dim=(1, 2))
        denom = p.sum(dim=(1, 2)) + t.sum(dim=(1, 2))
        return ((2 * overlap + epsilon) / (denom + epsilon)).mean()

    per_class = [_class_dice(cls_idx) for cls_idx in range(num_classes)]
    return torch.stack(per_class).mean(), per_class
|
|
def focal_loss(pred, target, alpha=0.25, gamma=2.):
    """Mean binary focal loss over all elements.

    Args:
        pred: predicted probabilities in [0, 1] (not logits; this uses
            ``F.binary_cross_entropy``), any shape and memory layout.
        target: binary targets with the same number of elements as ``pred``. Cast to
            ``pred``'s dtype, so integer or bool targets are accepted.
        alpha: constant weight applied to every element.
        gamma: focusing exponent; larger values down-weight well-classified elements.

    Returns:
        0-dim tensor holding the mean focal loss.
    """
    # reshape instead of view: works for non-contiguous inputs (e.g. transposed tensors).
    pred = pred.reshape(-1)
    target = target.reshape(-1).to(pred.dtype)
    bce = F.binary_cross_entropy(pred, target, reduction='none')
    # exp(-BCE) is the probability the model assigns to the true class.
    pt = torch.exp(-bce)
    return (alpha * (1 - pt) ** gamma * bce).mean()


def map_label(lbl):
    """Collapse a fine-grained label id into a coarse group id.

    Ids 0-14, 24 and 29 map to 24. Ids 15-20, 22 and 25-28 map to 23.
    Every other id is returned unchanged.

    (A stray ``@torch.no_grad()`` decorator sat on this pure-Python function; it was
    evidently meant for ``evaluate`` below and has been removed here.)
    """
    if 0 <= lbl <= 14 or lbl in (24, 29):
        return 24
    if 15 <= lbl <= 20 or lbl == 22 or 25 <= lbl <= 28:
        return 23
    return lbl


def _morphology_metrics(all_pred_morphology, all_gt_morphology, labels):
    """Compute morphology accuracy and F1 scores from matched per-query label tensors.

    Args:
        all_pred_morphology: list of predicted integer label tensors, one per target.
        all_gt_morphology: list of the matching ground-truth label tensors.
        labels: attribute names; their count is the number of attributes per row.

    Returns:
        ``(accuracy, per_label_f1, overall_f1, computed)``:
        * ``per_label_f1`` is a list aligned with ``labels``.
        * ``computed`` is False when nothing could be scored; in that case
          ``accuracy`` and ``overall_f1`` are NaN and every per-label F1 is 0.0.
    """
    nothing = (float('nan'), [0.0] * len(labels), float('nan'), False)
    if not all_gt_morphology:
        return nothing

    pred = torch.cat(all_pred_morphology, dim=0).cpu().numpy().reshape(-1)
    gt = torch.cat(all_gt_morphology, dim=0).cpu().numpy().reshape(-1)
    # Entries whose ground truth is 4 are excluded (presumably an ignore/padding value — confirm).
    valid = gt != 4
    pred_valid = pred[valid]
    gt_valid = gt[valid]
    if pred_valid.size == 0:
        return nothing

    # NOTE(review): reshaping after masking assumes whole rows are kept or dropped together;
    # a partially masked row would misalign attributes (or make the reshape fail).
    pred_rows = pred_valid.reshape(-1, len(labels))
    gt_rows = gt_valid.reshape(-1, len(labels))

    per_label_f1 = []
    for col in range(len(labels)):
        pred_col = pred_rows[:, col]
        gt_col = gt_rows[:, col]
        average = "binary" if len(set(pred_col) | set(gt_col)) <= 2 else "macro"
        per_label_f1.append(f1_score(gt_col, pred_col, average=average))

    overall_f1 = f1_score(gt_rows, pred_rows, average="macro")
    accuracy = accuracy_score(gt_valid, pred_valid)
    return accuracy, per_label_f1, overall_f1, True


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, wo_class_error=False, args=None, logger=None):
    """Evaluate ``model`` on ``data_loader`` for whichever task the targets describe.

    For every batch the model is run (also given the targets when ``args.use_dn`` is set)
    and the criterion losses are logged. Per-target results are collected based on the
    keys present in each target dict:
      * ``'segmentation'``: Dice of ``sigmoid(outputs['pred_mask'])`` against
        ``target['binary_mask']``; predicted masks are also saved as PNG files.
      * ``'classification'``: arg-max of ``outputs['pred_image_class']`` against
        ``target['category_id']``.
      * ``'masked_traning'``: ``outputs['pred_text']`` against ``outputs['completed_text']``.
      * ``'morphology'``: matched per-query attribute predictions against
        ``target['morphology']``.
    When none of the first three has been seen, detections are post-processed and fed to
    the COCO / panoptic evaluators.

    Returns:
        * ``(stats, 1)`` for segmentation (adds ``'Dice_Score'``), text (BLEU-1..4 and
          ROUGE) and classification (weighted ``'F1_Score'``) runs.
        * Otherwise ``(stats, coco_evaluator)``, where ``stats`` holds morphology
          accuracy/F1 plus any COCO / panoptic summaries, and ``coco_evaluator`` is None
          unless ``args.eval_type == "det"``.
    """
    need_tgt_for_training = getattr(args, 'use_dn', False)

    model.eval()
    criterion.eval()

    # Per-task accumulators; which ones fill up depends on the keys present in the targets.
    all_gt_morphology = []
    all_pred_morphology = []
    seg_mask_t = []
    classification_t = []
    classification_p = []
    Dice_score_all = []
    text_t = []
    text_p = []
    prompt_text = []
    # Predicted segmentation masks are written to this directory as PNG files.
    mask_dir = "Malaria-Detection-2019_test_data"

    metric_logger = utils.MetricLogger(delimiter=" ")
    if not wo_class_error:
        metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    useCats = getattr(args, 'useCats', True)
    if not useCats:
        print("useCats: {} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!".format(useCats))
    # Always bind coco_evaluator. It used to be undefined when eval_type != "det", so the
    # later `coco_evaluator is not None` checks raised UnboundLocalError.
    coco_evaluator = None
    if args.eval_type == "det":
        coco_evaluator = CocoEvaluator(base_ds, iou_types, useCats=useCats)

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    _cnt = 0
    output_state_dict = {}
    for samples, targets, prompt in metric_logger.log_every(data_loader, 100, header, logger=logger):
        samples = samples.to(device)
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        with torch.cuda.amp.autocast(enabled=args.amp):
            if need_tgt_for_training:
                outputs = model(samples, prompt, targets)
            else:
                outputs = model(samples, prompt)
            loss_dict = criterion(outputs, targets, args)
        weight_dict = criterion.weight_dict

        # Morphology logits -> (batch, queries, 6 attributes, 2 classes) -> arg-max class.
        pred_morphology = outputs['pred_morphology']
        num_attributes = 6
        num_classes_per_attribute = 2
        pred_morphology = pred_morphology.view(
            pred_morphology.size(0),
            pred_morphology.size(1),
            num_attributes,
            num_classes_per_attribute,
        )
        pred_morphology_labels = F.softmax(pred_morphology, dim=-1).argmax(-1)

        for i, target in enumerate(targets):
            if 'segmentation' in target:
                seg_mask_t.append(target['segmentation'])
                # NOTE(review): the whole batch's masks are scored and saved against this one
                # target (and under one file name); only correct for batch size 1.
                pred_probs = torch.sigmoid(outputs['pred_mask'])
                Dice_score_all.append(dice_score2(pred_probs, target['binary_mask']))
                os.makedirs(mask_dir, exist_ok=True)
                for j in range(pred_probs.shape[0]):
                    mask = (pred_probs[j, 0] * 255).cpu().detach().numpy().astype('uint8')
                    image_id = target['mask'].item()
                    Image.fromarray(mask).save(os.path.join(mask_dir, f"pred_mask_{image_id}.png"))

            if 'classification' in target:
                probs = F.softmax(outputs['pred_image_class'], dim=1)
                pred_class = probs.argmax(dim=1).item()
                target_class = target['category_id'].item()
                classification_t.append(target_class)
                classification_p.append(pred_class)

            if 'masked_traning' in target:
                text_t.extend(outputs['completed_text'])
                text_p.extend(outputs['pred_text'])
                prompt_text.append(target['prompt'])

            if 'morphology' in target:
                gt_morphology = target['morphology']
                # Match queries to this target's objects and keep only the matched pairs.
                indices = criterion.matcher(outputs, [target])[0]
                src_idx = criterion._get_src_permutation_idx([indices])[1]
                tgt_idx = criterion._get_tgt_permutation_idx([indices])[1]
                all_pred_morphology.append(pred_morphology_labels[i, src_idx])
                all_gt_morphology.append(gt_morphology[tgt_idx])

        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        if 'class_error' in loss_dict_reduced:
            metric_logger.update(class_error=loss_dict_reduced['class_error'])

        # Detection post-processing runs only while no segmentation/classification/text
        # targets have been seen.
        if not (seg_mask_t or classification_t or text_t):
            orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
            results = postprocessors['bbox'](outputs, orig_target_sizes)
            score_threshold = 0.1
            for result in results:
                keep = result["scores"] > score_threshold
                for key in result.keys():
                    result[key] = result[key][keep]

            if 'segm' in postprocessors.keys():
                target_sizes = torch.stack([t["size"] for t in targets], dim=0)
                results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
            res = {target['image_id'].item(): output for target, output in zip(targets, results)}

            if coco_evaluator is not None:
                coco_evaluator.update(res)

            if panoptic_evaluator is not None:
                # NOTE(review): target_sizes is only bound when a 'segm' postprocessor exists.
                res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
                for i, target in enumerate(targets):
                    image_id = target["image_id"].item()
                    res_pano[i]["image_id"] = image_id
                    res_pano[i]["file_name"] = f"{image_id:012d}.png"
                panoptic_evaluator.update(res_pano)

            if args.save_results:
                for tgt, res_item, outbbox in zip(targets, results, outputs['pred_boxes']):
                    # Ground truth rows: (box..., label).
                    gt_info = torch.cat((tgt['boxes'], tgt['labels'].unsqueeze(-1)), 1)
                    # Prediction rows: (box..., score, label).
                    # NOTE(review): outbbox comes straight from outputs['pred_boxes'], while
                    # res_item was filtered by score_threshold above; if their lengths differ
                    # this torch.cat fails. Confirm which boxes should be saved.
                    res_info = torch.cat(
                        (outbbox, res_item['scores'].unsqueeze(-1), res_item['labels'].unsqueeze(-1)), 1)
                    output_state_dict.setdefault('gt_info', []).append(gt_info.cpu())
                    output_state_dict.setdefault('res_info', []).append(res_info.cpu())

        _cnt += 1
        if args.debug and _cnt % 15 == 0:
            print("BREAK!" * 5)
            break

    if seg_mask_t:
        D_Score = sum(Dice_score_all) / len(Dice_score_all)
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
        stats['Dice_Score'] = D_Score
        return stats, 1

    if text_t:
        # NOTE(review): word_tokenize, SmoothingFunction and corpus_bleu are not imported in
        # this file (they match NLTK's nltk.tokenize / nltk.translate.bleu_score API); this
        # branch raises NameError until they are imported at module level.
        tokenized_preds = [word_tokenize(pred.lower()) for pred in text_p]
        tokenized_refs = [[word_tokenize(ref.lower())] for ref in text_t]
        smoothing = SmoothingFunction().method4
        bleu_scores_4 = corpus_bleu(tokenized_refs, tokenized_preds,
                                    weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
        bleu_scores_2 = corpus_bleu(tokenized_refs, tokenized_preds,
                                    weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing)
        bleu_scores_3 = corpus_bleu(tokenized_refs, tokenized_preds,
                                    weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing)
        bleu_scores_1 = corpus_bleu(tokenized_refs, tokenized_preds,
                                    weights=(1, 0, 0, 0), smoothing_function=smoothing)
        rouge_scores = compute_rouge(text_p, text_t)

        # zip stops at the shortest list: prompts are collected per target but texts per
        # model output, so the lists can differ in length (indexing used to overrun).
        with open("predictions_Backbone_QA.txt", "a", encoding="utf-8") as f:
            for prompt_i, pred_i, target_i in zip(prompt_text, text_p, text_t):
                f.write(f"Prompt: {prompt_i}\n")
                f.write(f"Prediction: {pred_i}\n")
                f.write(f"Target: {target_i}\n")
                f.write("-" * 50 + "\n")
        print("Txt file save")

        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        print("\nROUGE Scores:")
        for k, v in rouge_scores.items():
            print(f"{k}: {v:.4f}")
        stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
        stats['bleu_scores_1_result'] = bleu_scores_1
        stats['bleu_scores_2_result'] = bleu_scores_2
        stats['bleu_scores_3_result'] = bleu_scores_3
        stats['bleu_scores_4_result'] = bleu_scores_4
        stats['rouge_scores'] = {k: round(v, 4) for k, v in rouge_scores.items()}
        return stats, 1

    if classification_t:
        f1_score_classification = f1_score(classification_t, classification_p, average='weighted')
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
        stats['F1_Score'] = f1_score_classification
        return stats, 1

    # Detection / morphology path. `labels` must exist even without morphology data
    # (it used to be undefined there, raising NameError further down).
    labels = ["NC", "NS", "N", "C", "CB", "CV"]
    accuracy, f1_scores_list, overall_f1, morphology_computed = _morphology_metrics(
        all_pred_morphology, all_gt_morphology, labels)

    morphology_accuracy = torch.tensor([accuracy], device=device)
    morphology_f1_scores = torch.tensor(f1_scores_list, device=device)
    overall_f1_tensor = torch.tensor([overall_f1], device=device)

    # Average the metrics across processes.
    if utils.is_dist_avail_and_initialized():
        torch.distributed.all_reduce(morphology_accuracy)
        torch.distributed.all_reduce(morphology_f1_scores)
        torch.distributed.all_reduce(overall_f1_tensor)
        world_size = utils.get_world_size()
        morphology_accuracy /= world_size
        morphology_f1_scores /= world_size
        overall_f1_tensor /= world_size

    if args.save_results:
        savepath = os.path.join(args.output_dir, 'results-{}.pkl'.format(utils.get_rank()))
        print("Saving res to {}".format(savepath))
        torch.save(output_state_dict, savepath)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        coco_evaluator.count_false_positives(iou_type="bbox")
        coco_evaluator.count_false_positives_far(iou_type="bbox", iou_threshold=0.7, distance_threshold=50)

        cocoeval = coco_evaluator.coco_eval['bbox']
        iou_index = np.where(cocoeval.params.iouThrs == 0.5)[0]
        area_index = cocoeval.params.areaRngLbl.index("all")
        maxdet_index = cocoeval.params.maxDets.index(100)
        print(cocoeval.eval["precision"][iou_index, :, :, area_index, maxdet_index].mean(axis=1))

        # Recall averaged over IoU thresholds and classes, ignoring -1 (no ground truth) cells.
        area_index = 0
        maxdet_index = 2
        ar = cocoeval.eval['recall'][:, :, area_index, maxdet_index]
        average_recall = ar[ar > -1].mean()
        # The label is built from the indices actually used; it used to say maxDets=1
        # regardless of maxdet_index.
        print(f"Average Recall @IoU=0.50:0.95 | area={cocoeval.params.areaRngLbl[area_index]} "
              f"| maxDets={cocoeval.params.maxDets[maxdet_index]} =", average_recall)
        print(cocoeval.eval['recall'][5, :, area_index, maxdet_index])

    # Summarize panoptic results once (this used to happen twice).
    panoptic_res = panoptic_evaluator.summarize() if panoptic_evaluator is not None else None

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
    stats['morphology_accuracy'] = morphology_accuracy.item()
    # Per-label F1 whenever morphology metrics were computed (the old `any(predictions)`
    # check dropped them when every prediction was 0).
    if morphology_computed:
        for i, label_class in enumerate(labels):
            stats[f"f1_{label_class}"] = morphology_f1_scores[i].item()
    stats['overall_f1'] = overall_f1_tensor.item()

    print("______ MORPHOLOGY RESULTS__________")
    for label, f1 in zip(labels, morphology_f1_scores.tolist()):
        print(f"F1 Score for {label} | {f1:.4f}")

    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]

    return stats, coco_evaluator
|
|
|
|
@torch.no_grad()
def test(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, wo_class_error=False, args=None, logger=None):
    """Run inference over ``data_loader`` and collect COCO-style detection records.

    Each batch must be ``(samples, targets)``. The model is called on ``samples`` only.
    Boxes are produced by ``postprocessors['bbox']`` with ``not_to_xyxy=True`` and refined
    by ``postprocessors['segm']`` when present.

    Returns:
        list of dicts with keys ``"image_id"``, ``"category_id"``, ``"bbox"`` and
        ``"score"``, one per predicted box. When ``args.output_dir`` is set the list is
        also written to ``<output_dir>/results<rank>.json``.
    """
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # NOTE(review): constructed as in evaluate() but never updated in this function.
    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    final_res = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header, logger=logger):
        samples = samples.to(device)
        targets = [{k: to_device(v, device) for k, v in t.items()} for t in targets]

        outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes, not_to_xyxy=True)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)

        # Keyed by image id, so a repeated id within a batch keeps only its last prediction.
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        for image_id, prediction in res.items():
            scores = prediction['scores'].tolist()
            labels = prediction['labels'].tolist()
            boxes = prediction['boxes'].tolist()
            for score, label, box in zip(scores, labels, boxes):
                assert isinstance(label, int)
                final_res.append({
                    "image_id": int(image_id),
                    "category_id": label,
                    "bbox": box,
                    "score": score,
                })

    if args.output_dir:
        import json
        with open(os.path.join(args.output_dir, f'results{args.rank}.json'), 'w') as f:
            json.dump(final_res, f)

    return final_res
|
|