| import os |
| import numpy as np |
| import torch |
|
|
| from Generator.utils import fast_3D_interp_torch, myzoom_torch |
| from Trainer.models import build_model, build_inpaint_model |
| from utils.checkpoint import load_checkpoint |
| import utils.misc as utils |
|
|
# Default compute device: the current CUDA device index when a GPU is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'
|
|
|
|
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|
|
|
# Default YAML config files. The gen / train / val ones are passed to
# utils.preprocess_cfg by the evaluate_* functions in this module.
# NOTE(review): these are site-specific absolute paths; they must exist on the host running this.
submit_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/submit.yaml'
default_gen_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_gen.yaml'
default_train_cfg_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_train.yaml'
default_val_file = '/autofs/vast/lemon/temp_stuff/brainfm/cfg/defaults/default_val.yaml'

# Passed as `cfg_dir` to utils.preprocess_cfg (empty by default).
gen_cfg_dir, train_cfg_dir = '', ''

# Atlas volume; read at import time to build the tensors used by get_deformed_atlas.
atlas_path = '/autofs/vast/lemon/temp_stuff/brainfm/files/gca.mgz'
|
|
|
|
# Load the atlas once at import time.
#   MNI: atlas intensities as a float32 tensor on `device`.
#   A:   inverse of the atlas affine returned by MRIread (presumably world -> voxel;
#        confirm against utils.MRIread), float32 on `device`.
MNI, aff2 = utils.MRIread(atlas_path)
A = torch.tensor(np.linalg.inv(aff2), device=device, dtype=torch.float32)
MNI = torch.tensor(MNI, device=device, dtype=torch.float32)
|
|
def get_deformed_atlas(brain_labels, regx, regy, regz):
    """Sample the module-level atlas ``MNI`` at the coordinates of a registration field.

    Only voxels with ``brain_labels > 0`` are sampled. Their coordinates
    ``(regx, regy, regz)`` are multiplied by 100, mapped through the 4x4 matrix ``A``
    (the inverse of the atlas affine) and linearly interpolated in ``MNI``.
    Every other voxel of the result is 0.

    NOTE(review): the factor 100 presumably undoes a /100 scaling of the stored
    registration field; confirm with whoever produces regx/regy/regz.

    Returns:
        A tensor shaped like ``regx`` holding the deformed atlas.
    """
    inside = brain_labels > 0
    px, py, pz = (100 * comp[inside] for comp in (regx, regy, regz))
    # Row r of A applied to the homogeneous point (px, py, pz, 1).
    vox = [A[r, 0] * px + A[r, 1] * py + A[r, 2] * pz + A[r, 3] for r in range(3)]
    sampled = fast_3D_interp_torch(MNI, vox[0], vox[1], vox[2], 'linear')
    deformed = torch.zeros_like(regx)
    deformed[inside] = sampled
    return deformed
|
|
|
|
def zero_crop(orig, tol = 0, crop_range_lst = None, save_path = None):
    """Crop a 3D volume to the bounding box of its voxels above ``tol``.

    Args:
        orig: 3D tensor, indexed as ``orig[x, y, z]``.
        tol: voxels with a value strictly greater than ``tol`` define the box.
        crop_range_lst: optional explicit box ``[[x0, y0, z0], [x1, y1, z1]]``
            (end-exclusive). When given, ``tol`` is ignored.
        save_path: unused; kept for backward compatibility.

    Returns:
        The cropped view ``orig[x0:x1, y0:y1, z0:z1]``. If no voxel exceeds ``tol``,
        there is no bounding box and ``orig`` is returned uncropped.
    """
    if crop_range_lst is None:
        coords = torch.argwhere(orig > tol)
        if coords.shape[0] == 0:
            # min()/max() over an empty tensor would raise; there is nothing to crop to.
            return orig
        x0, y0, z0 = coords.min(dim=0)[0].tolist()
        x1, y1, z1 = (coords.max(dim=0)[0] + 1).tolist()
    else:
        [[x0, y0, z0], [x1, y1, z1]] = crop_range_lst

    return orig[x0 : x1, y0 : y1, z0 : z1]
|
|
|
|
def _tile_ranges(length, win, step):
    """Return the (start, end) windows of size min(win, length), stepped by `step`, tiling one axis."""
    if win <= 0 or step <= 0:
        raise ValueError('win_size and stride must be positive, got %s and %s' % (win, step))
    win = min(win, length)
    # A step larger than the window would leave voxels between windows uncovered.
    step = min(step, win)
    start = 0
    ranges = [(0, win)]
    while start + win < length:
        # Slide by `step`, but clamp the last window so it ends exactly at `length`.
        start = min(start + step, length - win)
        ranges.append((start, start + win))
    return ranges


def tiling(img, stride = (40, 40, 40), win_size = (160, 160, 160), zero_crop_first = False):
    """Split a volume into overlapping sliding-window patches.

    Along each spatial axis, windows of ``win_size`` voxels (or the whole axis if it is
    shorter) are placed every ``stride`` voxels. The last window is shifted back so it
    ends at the volume border. A stride larger than the window is reduced to the
    window size so that every voxel is covered.

    Args:
        img: 5D tensor whose last three axes are spatial; indexed ``img[:, :, x, y, z]``.
        stride: step between window starts per axis.
        win_size: window size per axis.
        zero_crop_first: if True, first crop ``img[0, 0]`` to its non-zero bounding box
            with ``zero_crop`` (only the first channel is kept).

    Returns:
        (img_list, cnt): ``img_list`` is a list of ``(patch, [(x0, x1), (y0, y1), (z0, z1)])``.
        ``cnt`` is a tensor shaped like ``img[0, 0]`` that counts how many patches cover
        each voxel, e.g. for averaging overlapping predictions.

    Raises:
        ValueError: if any stride or window size is not positive.
    """
    if zero_crop_first:
        img = zero_crop(img[0, 0])[None, None]
        print('shape after zero croppping:', img.shape[2:])
    orig_shp = img.shape[2:]

    x_list, y_list, z_list = (_tile_ranges(orig_shp[d], win_size[d], stride[d]) for d in range(3))

    img_list = []
    cnt = torch.zeros_like(img[0, 0])
    for (x_start, x_end) in x_list:
        for (y_start, y_end) in y_list:
            for (z_start, z_end) in z_list:
                curr_img = img[:, :, x_start : x_end, y_start : y_end, z_start : z_end]
                curr_range = [(x_start, x_end), (y_start, y_end), (z_start, z_end)]
                img_list.append((curr_img, curr_range))
                cnt[x_start : x_end, y_start : y_end, z_start : z_end] += 1
                print('Patch #'+str(len(img_list)), curr_img.shape[2:], curr_range)

    return img_list, cnt
| |
|
|
|
|
def center_crop(img, win_size = (220, 220, 220), zero_crop_first = False, aff = None):
    """Center-crop a 3D (or channel-last 4D) volume to at most ``win_size`` voxels per axis.

    Args:
        img: tensor of shape (X, Y, Z) or (X, Y, Z, C).
        win_size: maximum spatial size per axis, or None to skip the center crop.
        zero_crop_first: if True, first crop to the bounding box of voxels > 0 in the
            first channel. The same box is applied to every channel.
        aff: 4x4 voxel-to-world affine of ``img``; identity when None. It is never modified
            in place; an updated copy is returned.

    Returns:
        (cropped, crop_start, orig_shp, aff):
        cropped: shape (1, 1, x, y, z) for 3D input, or (1, x, y, z, C) for 4D input.
        crop_start: [x, y, z] offset of the center crop ([0, 0, 0] if none was needed).
        orig_shp: spatial shape before the center crop (after the optional zero crop).
        aff: affine of ``cropped``; its origin is shifted by every crop that was applied.
    """
    # Private copy: the old mutable default np.eye(4) was shifted in place on every call,
    # and callers' affines were mutated too.
    aff = np.eye(4) if aff is None else np.array(aff, dtype=float)

    if len(img.shape) == 4:
        img = torch.permute(img, (3, 0, 1, 2))[None]
        permuted = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        permuted = False

    orig_shp = img.shape[2:]

    if zero_crop_first:
        print(' before zero croppping:', orig_shp)
        coords = torch.argwhere(img[0, 0] > 0)
        if coords.shape[0] > 0:  # an all-zero volume has no bounding box; keep it whole
            lo = coords.min(dim=0)[0].tolist()
            hi = (coords.max(dim=0)[0] + 1).tolist()
            img = img[:, :, lo[0] : hi[0], lo[1] : hi[1], lo[2] : hi[2]]
            aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :-1] @ np.array(lo)
        orig_shp = img.shape[2:]
        print(' after zero croppping:', orig_shp)

    crop_start = [0, 0, 0]
    if win_size is not None and any(orig_shp[i] > win_size[i] for i in range(3)):
        crop_start = [max((orig_shp[i] - win_size[i]), 0) // 2 for i in range(3)]
        aff[:-1, -1] = aff[:-1, -1] + aff[:-1, :-1] @ np.array(crop_start)
        img = img[:, :, crop_start[0] : crop_start[0] + win_size[0],
                        crop_start[1] : crop_start[1] + win_size[1],
                        crop_start[2] : crop_start[2] + win_size[2]]

    if permuted:
        img = torch.permute(img, (0, 2, 3, 4, 1))
    return img, crop_start, orig_shp, aff
| |
|
|
|
|
def add_bias_field(I, bf_scale_min = 0.02, bf_scale_max = 0.04, bf_std_min = 0.1, bf_std_max = 0.6, device = 'cpu'):
    """Multiply ``I`` by a random smooth multiplicative bias field.

    A low-resolution Gaussian log-field is drawn with about ``bf_scale`` times the size of
    ``I`` per axis, where ``bf_scale`` ~ U[bf_scale_min, bf_scale_max]. Its amplitude is
    ``std`` ~ U[bf_std_min, bf_std_max]. The field is zoomed by ``I.shape / small_size``
    with ``myzoom_torch`` and exponentiated.

    Returns:
        (I_bf, BF): the biased image ``I * BF`` and the bias field ``BF``.
    """
    bf_scale = bf_scale_min + np.random.rand(1) * (bf_scale_max - bf_scale_min)
    # Keep at least one control point per axis. On small volumes the rounded size would
    # otherwise be 0, giving an empty field and a division by zero in the zoom factor.
    size_BF_small = np.maximum(np.round(bf_scale * np.array(I.shape)), 1).astype(int).tolist()
    bf_std = bf_std_min + (bf_std_max - bf_std_min) * np.random.rand(1)
    BFsmall = torch.tensor(bf_std, dtype=torch.float, device=device) * \
              torch.randn(size_BF_small, dtype=torch.float, device=device)
    BFlog = myzoom_torch(BFsmall, np.array(I.shape) / size_BF_small)
    BF = torch.exp(BFlog)
    I_bf = I * BF
    return I_bf, BF
|
|
def resample(I, orig_res = [1., 1., 1.], new_res = [1., 1., 1.]):
    """Degrade ``I`` to voxel size ``new_res`` and zoom it back by the inverse factor.

    ``I`` (voxel size ``orig_res``) is linearly interpolated at the voxel centres of a grid
    with voxel size ``new_res``. The result is then zoomed by ``1 / factors`` with
    ``utils.myzoom_torch``, so the output carries ``new_res``-level detail on a grid
    matching the input size (exact size depends on myzoom_torch).

    Args:
        I: 3D tensor.
        orig_res, new_res: voxel size per axis. A scalar, or a length-3 (or length-1)
            list / tuple / array.

    Returns:
        The resampled tensor.
    """
    def _per_axis(res):
        # Scalars, lists, tuples and arrays all become a float array of length 3.
        # Previously only lists were accepted; a tuple or array was wrapped as [res, res, res].
        return np.broadcast_to(np.asarray(res, dtype=float), (3,))

    orig_res = _per_axis(orig_res)
    resolution = _per_axis(new_res)
    shape = np.array(I.shape)
    # Truncate as before, but keep at least one voxel per axis so `factors` is never 0.
    new_size = np.maximum((shape * orig_res / resolution).astype(int), 1)

    factors = new_size / shape
    # Offset that puts each new sample at the centre of its (larger/smaller) voxel.
    delta = (1.0 - factors) / (2.0 * factors)
    vx = np.arange(delta[0], delta[0] + new_size[0] / factors[0], 1 / factors[0])[:new_size[0]]
    vy = np.arange(delta[1], delta[1] + new_size[1] / factors[1], 1 / factors[1])[:new_size[1]]
    vz = np.arange(delta[2], delta[2] + new_size[2] / factors[2], 1 / factors[2])[:new_size[2]]
    II, JJ, KK = np.meshgrid(vx, vy, vz, sparse=False, indexing='ij')
    II = torch.tensor(II, dtype=torch.float, device=I.device)
    JJ = torch.tensor(JJ, dtype=torch.float, device=I.device)
    KK = torch.tensor(KK, dtype=torch.float, device=I.device)

    I_resize = fast_3D_interp_torch(I, II, JJ, KK, 'linear')
    I_new = utils.myzoom_torch(I_resize, 1 / factors)

    return I_new
|
|
|
|
def read_image(img_path, is_label = False, device = 'cpu'):
    """Load the volume at ``img_path`` as a squeezed, NaN-free torch tensor on ``device``.

    Label volumes are read with dtype 'int' and stored as ``torch.int``; everything else
    is read as 'float' and stored as ``torch.float32``. The affine from ``utils.MRIread``
    is discarded.
    """
    if is_label:
        read_dtype, tensor_dtype = 'int', torch.int
    else:
        read_dtype, tensor_dtype = 'float', torch.float32
    volume, _aff = utils.MRIread(img_path, im_only=False, dtype=read_dtype)
    volume = torch.tensor(np.squeeze(volume), dtype=tensor_dtype, device=device)
    return torch.nan_to_num(volume)
|
|
|
|
def prepare_image(img_path, win_size = None, zero_crop_first = False, spacing = None,
                  add_bf = False, is_CT = False, is_label = False, rescale = True,
                  hemis_mask = None, im_only = False, device = 'cpu'):
    """Load a volume and build aligned, center-cropped tensors for evaluation.

    Processing steps:
      1. Read the volume and replace NaNs.
      2. Average a trailing 4th axis.
      3. Optionally clamp CT to [0, 80] and min-max rescale to [0, 1].
      4. Resize with ``utils.torch_resize(im, aff, 1.)``.
      5. Align every derived volume to the identity orientation using the image's own
         affine, then pass it through ``center_crop``.

    Args:
        img_path: path to the volume.
        win_size, zero_crop_first: forwarded to ``center_crop``.
        spacing: if not None, ``final`` is passed through ``resample(..., new_res=spacing)``.
        add_bf: multiply by a random bias field (ignored for CT).
        is_CT: clamp intensities to [0, 80] (when ``rescale``) and skip the bias field.
        is_label: read as integers and skip rescaling.
        rescale: enable the CT clamp and the min-max rescale.
        hemis_mask: if given, voxels of ``final`` where it is 0 are zeroed.
        im_only: return only ``final``.
        device: torch device for the loaded tensors.

    Returns:
        ``final`` if ``im_only``, else ``(final, orig, high_res, bf, aff, crop_start, orig_shp)``:
        orig: aligned/cropped image before the bias field and resampling.
        high_res: the image after the bias field (if any), aligned/cropped.
        final: high_res after optional resampling, aligned/cropped and hemisphere-masked.
        bf: aligned/cropped bias field, or None.
        aff: affine of ``orig`` after alignment and cropping.
        crop_start, orig_shp: as returned by ``center_crop`` for ``final``.
    """
    im, aff = utils.MRIread(img_path, im_only=False, dtype='int' if is_label else 'float')
    im = torch.tensor(np.squeeze(im), dtype=torch.int if is_label else torch.float32, device=device)
    im = torch.nan_to_num(im)

    if len(im.shape) > 3:
        # Collapse a trailing 4th axis by averaging.
        # NOTE(review): mean() fails on integer tensors, i.e. for 4D label volumes.
        im = im.mean(dim = -1)

    if is_CT and rescale:
        im = torch.clamp(im, min = 0., max = 80.)

    if not is_label and rescale:
        im -= torch.min(im)
        im_max = torch.max(im)
        # A constant image would otherwise be divided by 0 and become all-NaN.
        if im_max > 0:
            im /= im_max

    im, native_aff = utils.torch_resize(im, aff, 1.)

    def _align_and_crop(vol):
        # `vol` lives on the same grid/orientation as `im`, so it must be aligned with
        # im's native affine. Previously the already-aligned affine was used, which leaves
        # non-RAS inputs unaligned. center_crop gets a private copy of the aligned affine
        # so a crop can never shift an affine that later calls (or the caller) still use.
        vol, vol_aff = utils.align_volume_to_ref(vol, native_aff, aff_ref=np.eye(4), return_aff=True, n_dims=3)
        return center_crop(vol, win_size, zero_crop_first = zero_crop_first, aff = np.array(vol_aff, dtype=float))

    orig, crop_start, orig_shp, aff = _align_and_crop(im)

    if add_bf and not is_CT:
        high_res, bf = add_bias_field(im, device = device)
        bf, _, _, _ = _align_and_crop(bf)
    else:
        high_res, bf = im, None

    final = resample(high_res, new_res = spacing) if spacing is not None else high_res

    high_res, _, _, _ = _align_and_crop(high_res)
    final, crop_start, orig_shp, _ = _align_and_crop(final)

    if hemis_mask is not None:
        final[hemis_mask == 0] = 0

    if im_only:
        return final

    return final, orig, high_res, bf, aff, crop_start, orig_shp
|
|
|
|
|
|
| |
@torch.no_grad()
def evaluate_image(inputs, ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run a single-stage model, restored from ``ckp_path``, on one input.

    The model is built by ``build_model`` from the default configs merged with
    ``gen_cfg`` / ``model_cfg``. The input is wrapped as a one-element batch
    ``[{'input': inputs}]``. The output goes through every processor and then the
    postprocessor (if any).

    Returns:
        The last entry of the first sample's ``'feat'`` output when ``feature_only``,
        otherwise the first sample's whole output dict.
    """
    # NOTE(review): the model is never switched to eval(); confirm build_model does this.
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)
    batch = [{'input': inputs}]

    gen_args, train_args, feat_model, processors, _criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [feat_model], model_keys = ['model'], to_print = False)

    results, _aux = feat_model(batch)
    for proc in processors:
        results = proc(results, batch)
    if postprocessor is not None:
        results, _, _ = postprocessor(gen_args, train_args, results, batch, target = None, feats = None, tasks = gen_args.tasks)

    first = results[0]
    if not feature_only:
        return first
    return first['feat'][-1]
| |
|
|
@torch.no_grad()
def evaluate_image_twostage(inputs, pathol_ckp_path, task_ckp_path, feature_only = True, device = 'cpu', gen_cfg = None, model_cfg = None):
    """Run the two-stage (pathology, then task) model on one input.

    Stage 1 produces a ``'pathology'`` output per sample. Stage 2 receives the input
    multiplied by ``(1 - pathology)`` as ``'input_masked'`` and the pathology outputs as
    ``cond``. Both stages' outputs are merged with ``utils.merge_list_of_dict`` and then
    post-processed (if a postprocessor exists).

    Returns:
        ``(feat_pathol[-1], feat_task[-1])`` of the first sample when ``feature_only``,
        otherwise the first sample's merged output dict.
    """
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)
    batch = [{'input': inputs}]

    (gen_args, train_args, pathol_model, task_model,
     pathol_processors, task_processors, _criterion, postprocessor) = build_inpaint_model(gen_args, train_args, device)
    for ckp, net in ((pathol_ckp_path, pathol_model), (task_ckp_path, task_model)):
        load_checkpoint(ckp, [net], model_keys = ['model'], to_print = False)

    # Stage 1: pathology prediction.
    pathol_out, _ = pathol_model(batch)
    for proc in pathol_processors:
        pathol_out = proc(pathol_out, batch)

    # Stage 2: task model on the pathology-masked input, conditioned on the pathology maps.
    for idx, sample in enumerate(batch):
        sample['input_masked'] = sample['input'] * (1 - pathol_out[idx]['pathology'])
    pathol_maps = [o['pathology'] for o in pathol_out]
    task_out, _ = task_model(batch, input_name = 'input_masked', cond = pathol_maps)
    for proc in task_processors:
        task_out = proc(task_out, batch)

    merged = utils.merge_list_of_dict(task_out, pathol_out)
    if postprocessor is not None:
        merged, _, _ = postprocessor(gen_args, train_args, merged, batch, target = None, feats = None, tasks = gen_args.tasks)

    first = merged[0]
    if not feature_only:
        return first
    return first['feat_pathol'][-1], first['feat_task'][-1]
|
|
|
|
|
|
@torch.no_grad()
def evaluate_path(input_paths, save_dir, ckp_path, win_size = [220, 220, 220],
                  save_input = False, aux_paths = {}, save_aux = False, exclude_keys = [],
                  mask_output = False, ext = '.nii.gz', device = 'cpu',
                  gen_cfg = None, model_cfg = None):
    """Run a single-stage model on every file in ``input_paths`` and save its tensor outputs.

    For input ``i``, results go to ``save_dir/<basename without .nii*>/``:
      * ``input<ext>`` if ``save_input``.
      * ``<k><ext>`` for each auxiliary volume ``aux_paths[k][i]`` if ``save_aux``.
        Keys containing 'label' are read as labels.
      * ``out_<key><ext>`` for every tensor output not listed in ``exclude_keys``.
        Negative values are zeroed in place.

    With ``mask_output``, outputs are multiplied by a binary mask. The mask comes from the
    last auxiliary key containing 'mask' or, if there is none, from the non-zero voxels of
    the input.
    """
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, gen_cfg], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, model_cfg], cfg_dir = train_cfg_dir)
    gen_args, train_args, model, processors, _criterion, postprocessor = build_model(gen_args, train_args, device)
    load_checkpoint(ckp_path, [model], model_keys = ['model'], to_print = False)

    for idx, input_path in enumerate(input_paths):
        print('Now testing: %s (%d/%d)' % (input_path, idx + 1, len(input_paths)))
        print(' ckp:', ckp_path)
        case_name = os.path.basename(input_path).split('.nii')[0]
        curr_save_dir = utils.make_dir(os.path.join(save_dir, case_name))

        im, _orig, _high_res, _bf, aff, _crop_start, _orig_shp = prepare_image(input_path, win_size, device = device)
        if save_input:
            print(' Input: saved in - %s' % (os.path.join(curr_save_dir, 'input' + ext)))
            utils.viewVolume(im, aff, names = ['input'], ext = ext, save_dir = curr_save_dir)

        mask = None
        for k, k_paths in aux_paths.items():
            im_k, _, _, _, _, _, _ = prepare_image(k_paths[idx], win_size, is_label = 'label' in k, device = device)
            if save_aux:
                print(' Aux input: %s - saved in - %s' % (k, os.path.join(curr_save_dir, k + ext)))
                utils.viewVolume(im_k, aff, names = [k], ext = ext, save_dir = curr_save_dir)
            if mask_output and 'mask' in k:
                mask = im_k.clone()
                mask[im_k != 0.] = 1.

        batch = [{'input': im}]
        results, _ = model(batch)
        for proc in processors:
            results = proc(results, batch)
        if postprocessor is not None:
            results, _, _ = postprocessor(gen_args, train_args, results, batch, target = None, feats = None, tasks = gen_args.tasks)

        if mask_output and mask is None:
            mask = torch.zeros_like(im)
            mask[im != 0.] = 1.

        for key, value in results[0].items():
            if key in exclude_keys or not isinstance(value, torch.Tensor):
                continue
            print(' Output: %s - saved in - %s' % (key, os.path.join(curr_save_dir, 'out_' + key + ext)))
            value[value < 0.] = 0.
            to_save = value * mask if mask_output else value
            utils.viewVolume(to_save, aff, names = ['out_'+key], ext = ext, save_dir = curr_save_dir)
|
|