# Uni-hema / engine.py
# ryhm's picture
# Upload 206 files
# 51067bb verified
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
from PIL import Image
import os
import sys
from typing import Iterable
import numpy as np
from util.utils import slprint, to_device
from sklearn.metrics import accuracy_score
import numpy as np
from itertools import zip_longest
import torch
# from medpy import metric
from compute_rouge import compute_rouge
import util.misc as utils
from dino_datasets.coco_eval import CocoEvaluator
from dino_datasets.panoptic_eval import PanopticEvaluator
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
# from nltk.translate.bleu_score import corpus_bleu
# import nltk
# from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
# from nltk.tokenize import word_tokenize
from itertools import zip_longest
# from evaluate import load
# Only needed once
# nltk.download('punkt', download_dir='nltk_data')
# nltk.download('punkt', download_dir='/home/iml_abdul/nltk_data')
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, data_loader_2: Iterable, data_loader_3: Iterable,
                    data_loader_4: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    wo_class_error=False, lr_scheduler=None, args=None, logger=None, ema_m=None):
    """Run one training epoch on the loader(s) selected by ``args.traning_step``.

    ``args.traning_step`` selects the data source(s):
        1 -> data_loader_3
        3 -> data_loader, data_loader_2 and data_loader_3 iterated in lockstep
             (iteration stops with the shortest loader)
        4 -> data_loader
        5 -> data_loader_2
        6 -> data_loader_4
    Every batch is unpacked as ``(samples, targets, prompt)``. When several loaders are
    used, the losses of their batches are summed and a single optimizer step is taken.

    Returns:
        dict mapping each logged meter to its global average; when the criterion has a
        ``loss_weight_decay`` attribute the current loss weights are added as ``weight_<k>``.

    Raises:
        ValueError: if ``args.traning_step`` is not one of 1, 3, 4, 5, 6.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    need_tgt_for_training = getattr(args, 'use_dn', False)

    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    if not wo_class_error:
        metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 200

    # Loaders consumed by each training step.
    loaders_by_step = {
        1: (data_loader_3,),
        3: (data_loader, data_loader_2, data_loader_3),
        4: (data_loader,),
        5: (data_loader_2,),
        6: (data_loader_4,),
    }
    if args.traning_step not in loaders_by_step:
        raise ValueError("Invalid training step")
    logged_loaders = [
        metric_logger.log_every(loader, print_freq, header, logger=logger)
        for loader in loaders_by_step[args.traning_step]
    ]

    _cnt = 0
    for batches in zip(*logged_loaders):
        losses = torch.as_tensor(0.).to(device)  # differentiable total over `batches`
        loss_value = 0.0  # logged total, built from the process-reduced losses
        combined_loss_dict_scaled = {}
        combined_loss_dict_unscaled = {}
        loss_dict_reduced = {}

        for i, batch in enumerate(batches):
            if batch is None:
                continue
            samples, targets, prompt = batch
            samples = samples.to(device)
            targets = [
                {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                for t in targets
            ]

            with torch.cuda.amp.autocast(enabled=args.amp):
                if need_tgt_for_training:
                    outputs = model(samples, prompt, targets)
                else:
                    outputs = model(samples, prompt)
                loss_dict = criterion(outputs, targets, args)
            weight_dict = criterion.weight_dict

            # Backward must use this process's own (unreduced) losses: the all-reduced
            # copy is for logging only (in the DETR/DINO utilities reduce_dict runs under
            # no_grad, so it has no graph). NOTE(review): confirm against util/misc.py.
            weighted = {k: v * weight_dict[k] for k, v in loss_dict.items() if k in weight_dict}
            # Fall back to the plain sum when no loss key has a configured weight.
            loss = sum(weighted.values()) if weighted else sum(loss_dict.values())
            losses = losses + loss

            # Logging: losses averaged across processes.
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_dict_reduced_scaled = {
                k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict
            }
            # Suffix keys with the 1-based loader index so each source is logged separately.
            for k, v in loss_dict_reduced.items():
                combined_loss_dict_unscaled[f'{k}_unscaled_{i+1}'] = v
            for k, v in loss_dict_reduced_scaled.items():
                combined_loss_dict_scaled[f'{k}_scaled_{i+1}'] = v
            if loss_dict_reduced_scaled:
                reduced_total = sum(loss_dict_reduced_scaled.values())
            else:
                reduced_total = sum(loss_dict_reduced.values())
            loss_value += float(reduced_total)

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if args.amp:
            scaler.scale(losses).backward()
            if max_norm > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            # NOTE(review): gradient clipping by `max_norm` is disabled on this path, and
            # retain_graph=True keeps this graph alive until the next iteration replaces
            # it (higher peak memory) — confirm both are intended.
            if losses != 0:
                losses.backward(retain_graph=True)
            optimizer.step()

        if args.onecyclelr:
            lr_scheduler.step()
        if args.use_ema and epoch >= args.ema_epoch:
            ema_m.update(model)

        metric_logger.update(loss=loss_value, **{
            **{k: v.item() for k, v in combined_loss_dict_unscaled.items()},
            **{k: v.item() for k, v in combined_loss_dict_scaled.items()},
        })
        if 'class_error' in loss_dict_reduced:
            metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        _cnt += 1
        if args.debug and _cnt % 15 == 0:
            print("BREAK!"*5)
            break

    if getattr(criterion, 'loss_weight_decay', False):
        criterion.loss_weight_decay(epoch=epoch)
    if getattr(criterion, 'tuning_matching', False):
        criterion.tuning_matching(epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    resstat = {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}
    if getattr(criterion, 'loss_weight_decay', False):
        resstat.update({f'weight_{k}': v for k, v in criterion.weight_dict.items()})
    return resstat
def plot_pred_mask(pred_mask, threshold=0.5, title="Predicted Mask"):
# If batch dimension exists, remove it
if pred_mask.dim() == 4:
pred_mask = pred_mask[0, 0]
elif pred_mask.dim() == 3:
pred_mask = pred_mask[0]
# Apply sigmoid and threshold if it's not binary yet
pred_mask = pred_mask.sigmoid() if pred_mask.max() > 1 else pred_mask
binary_mask = (pred_mask > threshold).float()
# Convert to numpy for plotting
binary_mask_np = binary_mask.detach().cpu().numpy()
# Plotting
plt.figure(figsize=(6, 6))
plt.imshow(binary_mask_np, cmap='gray')
plt.title(title)
plt.axis('off')
plt.show()
def dice_score2(pred, target, epsilon=1e-6):
    """Pooled Dice coefficient between ``pred`` and ``target`` after thresholding at 0.5.

    Both inputs may be tensors or lists (of tensors or numbers). All elements are pooled
    into one intersection / union, so a batch yields a single global Dice value.

    Args:
        pred: predicted probabilities or binary mask.
        target: ground-truth mask; must broadcast against ``pred``.
        epsilon: smoothing term that keeps the ratio defined for empty masks.

    Returns:
        float in [0, 1].
    """
    if isinstance(pred, list):
        pred = torch.stack(pred) if pred and isinstance(pred[0], torch.Tensor) else torch.tensor(pred)
    if isinstance(target, list):
        target = torch.stack(target) if target and isinstance(target[0], torch.Tensor) else torch.tensor(target)
    pred_bin = (pred.float().cpu() > 0.5).float()
    target_bin = (target.float().cpu() > 0.5).float()
    intersection = (pred_bin * target_bin).sum()
    union = pred_bin.sum() + target_bin.sum()
    # The ratio is already normalised; it must not be divided by the batch length.
    return ((2. * intersection + epsilon) / (union + epsilon)).item()
def multiclass_dice_score(pred, target, num_classes, epsilon=1e-6):
    """Soft multi-class Dice between class scores and an integer label map.

    Args:
        pred: (B, C, H, W) class scores; a softmax over dim 1 is always applied.
        target: (B, H, W) integer class indices in ``[0, num_classes)``.
        num_classes: number of classes ``C``.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        ``(mean_dice, dice_scores)``: the unweighted mean over classes, and a list with
        one 0-dim tensor per class (each averaged over the batch).
    """
    probabilities = F.softmax(pred, dim=1)
    # (B, H, W) -> (B, H, W, C) -> (B, C, H, W) so channels line up with `probabilities`.
    one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

    def _class_dice(cls):
        p = probabilities[:, cls, :, :]
        t = one_hot[:, cls, :, :]
        overlap = (p * t).sum(dim=(1, 2))
        total = p.sum(dim=(1, 2)) + t.sum(dim=(1, 2))
        return ((2 * overlap + epsilon) / (total + epsilon)).mean()

    per_class = [_class_dice(cls) for cls in range(num_classes)]
    return torch.mean(torch.stack(per_class)), per_class
def focal_loss(pred, target, alpha=0.25, gamma=2.):
    """Binary focal loss on probabilities, averaged over all elements.

    Args:
        pred: probabilities in [0, 1] (fed directly to ``F.binary_cross_entropy``);
            any shape and memory layout.
        target: tensor with the same number of elements as ``pred``.
        alpha: constant weight applied to every element (not the class-balanced
            ``alpha_t`` variant).
        gamma: focusing exponent on ``(1 - pt)``.

    Returns:
        0-dim tensor.
    """
    # reshape (not view) so non-contiguous inputs such as transposes work too.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    bce = F.binary_cross_entropy(pred, target, reduction='none')
    pt = torch.exp(-bce)  # probability assigned to the target class for binary targets
    return (alpha * (1 - pt) ** gamma * bce).mean()
def map_label(lbl):
    """Collapse label ids onto the two ids 24 and 23.

    Ids in 0-14 and the ids 24 and 29 map to 24. Ids in 15-20, id 22 and ids in 25-28
    map to 23. Any other value is returned unchanged. The only visible use is a
    commented-out remapping block inside ``evaluate``.
    """
    if 0 <= lbl <= 14 or lbl in (24, 29):
        return 24
    if 15 <= lbl <= 20 or lbl == 22 or 25 <= lbl <= 28:
        return 23
    return lbl
def _synchronized_stats(metric_logger):
    """Sync ``metric_logger`` across processes, print it, and return its global averages.

    Only meters that received at least one update are included.
    """
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0}


def _morphology_metrics(pred_batches, gt_batches, labels, ignore_label=4):
    """Accuracy and per-attribute F1 for matched morphology predictions.

    ``pred_batches`` / ``gt_batches`` are lists of tensors. They are concatenated along
    dim 0 and viewed as ``(num_matched_objects, len(labels))``, one column per attribute.
    Ground-truth entries equal to ``ignore_label`` are dropped column by column, so each
    attribute is scored only on its own valid entries and columns never shift.

    Returns:
        ``(accuracy, f1_scores, overall_f1, computed)``:
        - ``f1_scores`` is a list of floats in ``labels`` order.
        - ``overall_f1`` is the mean F1 over attributes with at least one valid entry.
        - ``computed`` is False when nothing could be scored. In that case accuracy and
          overall are NaN and the F1 list is all 0.0.
    """
    nan = float('nan')
    empty = (nan, [0.0] * len(labels), nan, False)
    if not gt_batches:
        return empty
    pred = torch.cat(pred_batches, dim=0).cpu().numpy().reshape(-1, len(labels))
    gt = torch.cat(gt_batches, dim=0).cpu().numpy().reshape(-1, len(labels))
    valid = gt != ignore_label
    if not valid.any():
        return empty

    accuracy = float(accuracy_score(gt[valid], pred[valid]))
    f1_scores = []
    scored = []
    for col in range(len(labels)):
        keep = valid[:, col]
        gt_col = gt[keep, col]
        pred_col = pred[keep, col]
        if gt_col.size == 0:
            f1_scores.append(0.0)
            continue
        # Binary F1 (positive class 1) when the column uses at most two values, else macro.
        n_values = len(set(gt_col.tolist()) | set(pred_col.tolist()))
        average = "binary" if n_values <= 2 else "macro"
        f1 = float(f1_score(gt_col, pred_col, average=average))
        f1_scores.append(f1)
        scored.append(f1)
    return accuracy, f1_scores, float(np.mean(scored)), True


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, wo_class_error=False, args=None, logger=None):
    """Evaluate ``model`` on ``data_loader`` and return ``(stats, evaluator)``.

    Targets are inspected per sample. The first matching case below (accumulated over
    the whole run) decides the reported metric:

    * ``'segmentation'``   -> pooled Dice score (``stats['Dice_Score']``), returns ``(stats, 1)``.
    * ``'masked_traning'`` -> BLEU-1..4 and ROUGE on generated text, returns ``(stats, 1)``.
    * ``'classification'`` -> weighted F1 (``stats['F1_Score']``), returns ``(stats, 1)``.
    * otherwise            -> box post-processing with COCO / panoptic evaluation plus
      morphology accuracy/F1. Returns ``(stats, coco_evaluator)``, where
      ``coco_evaluator`` is None unless ``args.eval_type == "det"``.

    While only detection/morphology targets have been seen, each batch is also
    post-processed for COCO/panoptic evaluation.
    """
    need_tgt_for_training = getattr(args, 'use_dn', False)

    model.eval()
    criterion.eval()

    # Per-task accumulators, filled per target and consumed after the loop.
    all_gt_morphology = []
    all_pred_morphology = []
    seg_mask_t = []
    dice_score_all = []
    classification_t = []
    classification_p = []
    text_t = []
    text_p = []
    prompt_text = []

    metric_logger = utils.MetricLogger(delimiter=" ")
    if not wo_class_error:
        metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    useCats = getattr(args, 'useCats', True)
    if not useCats:
        print("useCats: {} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!".format(useCats))
    # Only detection evaluation uses COCO. Default to None so every later check is safe.
    coco_evaluator = None
    if args.eval_type == "det":
        coco_evaluator = CocoEvaluator(base_ds, iou_types, useCats=useCats)

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    # NOTE(review): predicted segmentation masks are written to this hard-coded folder.
    mask_dir = "Malaria-Detection-2019_test_data"
    labels = ["NC", "NS", "N", "C", "CB", "CV"]  # morphology attribute names, one per column
    num_attributes = len(labels)
    num_classes_per_attribute = 2  # valid labels are 0 and 1

    _cnt = 0
    output_state_dict = {}  # for debug only
    for samples, targets, prompt in metric_logger.log_every(data_loader, 100, header, logger=logger):
        samples = samples.to(device)
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        with torch.cuda.amp.autocast(enabled=args.amp):
            if need_tgt_for_training:
                outputs = model(samples, prompt, targets)
            else:
                outputs = model(samples, prompt)
            loss_dict = criterion(outputs, targets, args)
        weight_dict = criterion.weight_dict

        # (B, Q, A*2) -> (B, Q, A, 2) -> predicted label per attribute, shape (B, Q, A)
        pred_morphology = outputs['pred_morphology']
        pred_morphology = pred_morphology.view(
            pred_morphology.size(0),
            pred_morphology.size(1),
            num_attributes,
            num_classes_per_attribute,
        )
        pred_morphology_labels = F.softmax(pred_morphology, dim=-1).argmax(-1)

        for i, target in enumerate(targets):
            if 'segmentation' in target:
                seg_mask_t.append(target['segmentation'])
                # Score and save only this target's prediction. Keeping a leading batch
                # dim of 1 lets dice_score2 see a single-sample batch.
                pred_probs = torch.sigmoid(outputs['pred_mask'][i:i + 1])
                dice_score_all.append(dice_score2(pred_probs, target['binary_mask']))
                mask = (pred_probs[0, 0] * 255).cpu().numpy().astype('uint8')  # scale 0-255
                os.makedirs(mask_dir, exist_ok=True)
                image_id = target['mask'].item()
                Image.fromarray(mask).save(os.path.join(mask_dir, f"pred_mask_{image_id}.png"))

            if 'classification' in target:
                probs = F.softmax(outputs['pred_image_class'], dim=1)
                # Take this target's row. Calling `.item()` on the whole batch only works for B == 1.
                classification_t.append(target['category_id'].item())
                classification_p.append(probs.argmax(dim=1)[i].item())

            if 'masked_traning' in target:
                # NOTE(review): this extends with the whole batch's texts once per target,
                # which duplicates entries when batch size > 1 — confirm intent.
                text_t.extend(outputs['completed_text'])
                text_p.extend(outputs['pred_text'])
                prompt_text.append(target['prompt'])

            if 'morphology' in target:
                gt_morphology = target['morphology']
                indices = criterion.matcher(outputs, [target])[0]
                src_idx = criterion._get_src_permutation_idx([indices])[1]
                tgt_idx = criterion._get_tgt_permutation_idx([indices])[1]
                all_pred_morphology.append(pred_morphology_labels[i, src_idx])
                all_gt_morphology.append(gt_morphology[tgt_idx])

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        if 'class_error' in loss_dict_reduced:
            metric_logger.update(class_error=loss_dict_reduced['class_error'])

        # Box post-processing only while no segmentation/classification/text target has been seen.
        if not (seg_mask_t or classification_t or text_t):
            orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
            results = postprocessors['bbox'](outputs, orig_target_sizes)
            score_threshold = 0.1
            for result in results:
                keep = result["scores"] > score_threshold
                for key in list(result.keys()):
                    result[key] = result[key][keep]
            if 'segm' in postprocessors.keys():
                target_sizes = torch.stack([t["size"] for t in targets], dim=0)
                results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
            res = {target['image_id'].item(): output for target, output in zip(targets, results)}
            if coco_evaluator is not None:
                coco_evaluator.update(res)

            if panoptic_evaluator is not None:
                target_sizes = torch.stack([t["size"] for t in targets], dim=0)
                res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
                for i, target in enumerate(targets):
                    image_id = target["image_id"].item()
                    res_pano[i]["image_id"] = image_id
                    res_pano[i]["file_name"] = f"{image_id:012d}.png"
                panoptic_evaluator.update(res_pano)

            if args.save_results:
                # NOTE(review): `outbbox` is the raw per-query box tensor, while scores and
                # labels are threshold-filtered. The cat below fails whenever filtering removed
                # queries — confirm the intended box source.
                for tgt, res_i, outbbox in zip(targets, results, outputs['pred_boxes']):
                    gt_info = torch.cat((tgt['boxes'], tgt['labels'].unsqueeze(-1)), 1)
                    res_info = torch.cat(
                        (outbbox, res_i['scores'].unsqueeze(-1), res_i['labels'].unsqueeze(-1)), 1)
                    output_state_dict.setdefault('gt_info', []).append(gt_info.cpu())
                    output_state_dict.setdefault('res_info', []).append(res_info.cpu())

        _cnt += 1
        if args.debug and _cnt % 15 == 0:
            print("BREAK!"*5)
            break

    if seg_mask_t:
        d_score = sum(dice_score_all) / len(dice_score_all)
        stats = _synchronized_stats(metric_logger)
        stats['Dice_Score'] = d_score
        return stats, 1

    if text_t:
        # NOTE(review): word_tokenize, SmoothingFunction and corpus_bleu come from NLTK.
        # Their imports are commented out at the top of this file, so this branch raises
        # NameError until they are restored.
        tokenized_preds = [word_tokenize(pred.lower()) for pred in text_p]
        tokenized_refs = [[word_tokenize(ref.lower())] for ref in text_t]
        smoothing = SmoothingFunction().method4
        bleu_weights = {
            1: (1, 0, 0, 0),
            2: (0.5, 0.5, 0, 0),
            3: (0.33, 0.33, 0.33, 0),
            4: (0.25, 0.25, 0.25, 0.25),
        }
        bleu = {
            n: corpus_bleu(tokenized_refs, tokenized_preds, weights=w, smoothing_function=smoothing)
            for n, w in bleu_weights.items()
        }
        rouge_scores = compute_rouge(text_p, text_t)
        with open("predictions_Backbone_QA.txt", "a", encoding="utf-8") as f:
            for prompt_i, pred_i, ref_i in zip(prompt_text, text_p, text_t):
                f.write(f"Prompt: {prompt_i}\n")
                f.write(f"Prediction: {pred_i}\n")
                f.write(f"Target: {ref_i}\n")
                f.write("-" * 50 + "\n")
        print("Txt file save")
        stats = _synchronized_stats(metric_logger)
        print("\nROUGE Scores:")
        for k, v in rouge_scores.items():
            print(f"{k}: {v:.4f}")
        for n in (1, 2, 3, 4):
            stats[f'bleu_scores_{n}_result'] = bleu[n]
        stats['rouge_scores'] = {k: round(v, 4) for k, v in rouge_scores.items()}
        return stats, 1

    if classification_t:
        f1_score_classification = f1_score(classification_t, classification_p, average='weighted')
        stats = _synchronized_stats(metric_logger)
        stats['F1_Score'] = f1_score_classification
        return stats, 1

    accuracy, f1_scores, overall_f1, morphology_computed = _morphology_metrics(
        all_pred_morphology, all_gt_morphology, labels)
    morphology_accuracy = torch.tensor([accuracy], device=device)
    morphology_f1_scores = torch.tensor(f1_scores, device=device)
    overall_f1_tensor = torch.tensor([overall_f1], device=device)
    if utils.is_dist_avail_and_initialized():
        world_size = utils.get_world_size()
        for metric in (morphology_accuracy, morphology_f1_scores, overall_f1_tensor):
            torch.distributed.all_reduce(metric)
            metric /= world_size

    if args.save_results:
        savepath = os.path.join(args.output_dir, 'results-{}.pkl'.format(utils.get_rank()))
        print("Saving res to {}".format(savepath))
        torch.save(output_state_dict, savepath)

    # gather the stats from all processes
    stats = _synchronized_stats(metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        coco_evaluator.count_false_positives(iou_type="bbox")
        coco_evaluator.count_false_positives_far(iou_type="bbox", iou_threshold=0.7, distance_threshold=50)
        cocoeval = coco_evaluator.coco_eval['bbox']
        iou_index = np.where(cocoeval.params.iouThrs == 0.5)[0]
        area_index = cocoeval.params.areaRngLbl.index("all")
        maxdet_index = cocoeval.params.maxDets.index(100)
        # pycocotools precision is [IoU, recall-thr, class, area, maxDets]; average over
        # recall thresholds -> per-class precision at IoU=0.50.
        print(cocoeval.eval["precision"][iou_index, :, :, area_index, maxdet_index].mean(axis = 1))
        # Recall is [IoU, class, area, maxDets]; -1 marks categories without ground truth.
        ar = cocoeval.eval['recall'][:, :, area_index, maxdet_index]
        average_recall = ar[ar > -1].mean()
        print("Average Recall @IoU=0.50:0.95 | area=all | maxDets=100 =", average_recall)
        # Per-class recall at IoU threshold index 5 (0.75 with the default thresholds).
        print(cocoeval.eval['recall'][5, :, area_index, maxdet_index])

    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()

    stats['morphology_accuracy'] = morphology_accuracy.item()
    if morphology_computed:
        for i, label_class in enumerate(labels):
            stats[f"f1_{label_class}"] = morphology_f1_scores[i].item()
        stats['overall_f1'] = overall_f1_tensor.item()
    print("______ MORPHOLOGY RESULTS__________")
    for label, f1 in zip(labels, morphology_f1_scores.tolist()):
        print(f"F1 Score for {label} | {f1:.4f}")

    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator
@torch.no_grad()
def test(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, wo_class_error=False, args=None, logger=None):
    """Run inference over ``data_loader`` and collect COCO-format detection records.

    Batches may be ``(samples, targets)`` or ``(samples, targets, prompt)``. When a
    prompt is present it is passed to the model as its second argument.

    Returns:
        list of dicts with keys ``image_id``, ``category_id``, ``bbox`` and ``score``.
        When ``args.output_dir`` is set, the list is also written as JSON to
        ``<output_dir>/results<args.rank>.json``.
    """
    model.eval()
    criterion.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    final_res = []
    for batch in metric_logger.log_every(data_loader, 10, header, logger=logger):
        samples, targets = batch[0], batch[1]
        samples = samples.to(device)
        targets = [{k: to_device(v, device) for k, v in t.items()} for t in targets]
        if len(batch) > 2:
            outputs = model(samples, batch[2])
        else:
            outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes, not_to_xyxy=True)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}

        for image_id, detections in res.items():
            for score, label, box in zip(detections['scores'].tolist(),
                                         detections['labels'].tolist(),
                                         detections['boxes'].tolist()):
                assert isinstance(label, int)
                final_res.append({
                    "image_id": int(image_id),
                    "category_id": label,
                    "bbox": box,
                    "score": score,
                })

    if args.output_dir:
        import json
        with open(os.path.join(args.output_dir, f'results{args.rank}.json'), 'w') as f:
            json.dump(final_res, f)

    return final_res