| |
| |
| |
|
|
import datetime
import glob
import os
import shutil
import sys
import time
import warnings
from collections import defaultdict

# Silence every warning, including any raised while importing the libraries below.
warnings.filterwarnings("ignore")
# Put the parent of this script's directory on the import path so that the
# project packages `utils` and `Generator` imported below can be found.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import numpy as np

from utils.misc import make_dir, viewVolume, MRIread
import utils.test_utils as utils
from Generator.utils import fast_3D_interp_torch
|
|
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# Raw label IDs that are kept when remapping a label volume
# (presumably the left-hemisphere structures -- confirm against the label protocol).
label_list_left_segmentation = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42]

# Lookup table: raw label ID -> position in `label_list_left_segmentation`.
# Every ID not in the list (below 10000) maps to 0.
lut = torch.zeros(10000, dtype=torch.long, device=device)
for new_index, raw_label in enumerate(label_list_left_segmentation):
    lut[raw_label] = new_index
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
|
|
def prepare_paths(data_root, split_txt):
    """Group the test-split T1 scans by dataset.

    Dataset names come from the ``*T1w.nii`` files found directly in
    ``data_root``: each name is the part of a file name before its first ``.``.
    Each non-empty line of ``split_txt`` is a scan path. It is assigned to the
    dataset whose name equals the part of its basename before the first ``.``.

    Args:
        data_root: directory scanned (non-recursively) for ``*T1w.nii`` files.
        split_txt: text file listing one scan path per line.

    Returns:
        ``(names, datasets)``. ``datasets`` lists the distinct dataset names in
        the order they were first seen in the directory listing. ``names[i]``
        lists the split entries of ``datasets[i]``, in split-file order.
    """
    scans = glob.glob(os.path.join(data_root, '*' + 'T1w.nii'))
    # dict.fromkeys de-duplicates while preserving first-seen order (O(n)).
    datasets = list(dict.fromkeys(os.path.basename(path).split('.', 1)[0] for path in scans))
    print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(scans)) + ' scans in total')
    print('Dataset list', datasets)

    with open(split_txt, 'r') as split_file:
        split_names = [line.strip() for line in split_file if line.strip()]

    # Match the whole dataset component rather than using `startswith`, so
    # that e.g. "ADNI3.*" scans are not also put into the "ADNI" bucket.
    by_dataset = {dataset: [] for dataset in datasets}
    for name in split_names:
        bucket = by_dataset.get(os.path.basename(name).split('.', 1)[0])
        if bucket is not None:
            bucket.append(name)
    names = [by_dataset[dataset] for dataset in datasets]

    print('Num of testing data', sum(len(group) for group in names))

    return names, datasets
|
|
|
|
def get_info(t1):
    """Derive the sibling file paths of a subject from its T1 scan path.

    The last 7 characters of ``t1`` (expected to be ``'T1w.nii'``) are replaced
    by each sibling suffix.

    Returns:
        ``(modalities, aux)``.

        ``modalities`` always holds ``'T1'``. It also holds ``'T2'``, ``'FLAIR'``
        and ``'CT'`` (in that order), but only for the files that exist on disk.

        ``aux`` maps the keys ``label``, ``cerebral_label``, ``distance``,
        ``regx``, ``regy``, ``regz``, ``lp``, ``lw``, ``rp`` and ``rw`` to their
        derived paths. These paths are returned without checking that they exist.
    """
    stem = t1[:-len('T1w.nii')]

    modalities = {'T1': t1}
    for key, suffix in (('T2', 'T2w.nii'), ('FLAIR', 'FLAIR.nii'), ('CT', 'CT.nii')):
        candidate = stem + suffix
        if os.path.isfile(candidate):
            modalities[key] = candidate

    aux_suffixes = (
        ('label', 'brainseg_with_extracerebral.nii'),
        ('cerebral_label', 'brainseg.nii'),
        ('distance', 'brain_dist_map.nii'),
        ('regx', 'mni_reg.x.nii'),
        ('regy', 'mni_reg.y.nii'),
        ('regz', 'mni_reg.z.nii'),
        ('lp', 'lp_dist_map.nii'),
        ('lw', 'lw_dist_map.nii'),
        ('rp', 'rp_dist_map.nii'),
        ('rw', 'rw_dist_map.nii'),
    )
    aux = {key: stem + suffix for key, suffix in aux_suffixes}

    return modalities, aux
|
|
|
|
| |
|
|
|
|
# ---- Generator / trainer configuration files --------------------------------
# The 'hemis' generator config is selected for models whose postfix contains 'hemis'.
gen_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test.yaml'
gen_hemis_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test_hemis.yaml'
model_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/test/demo_test.yaml'

# ---- Preprocessing / output options -----------------------------------------
win_size = [160, 160, 160]  # forwarded as `win_size` to utils.prepare_image
mask_output = False  # True: multiply saved outputs by a binary mask of the non-zero input voxels

# Model output keys that are not written to disk.
exclude_keys = ['segmentation']

# ---- Test data ---------------------------------------------------------------
data_root = '/autofs/vast/lemon/data_curated/brain_mris_QCed'
split_txt = '/autofs/vast/lemon/temp_stuff/peirong/train_test_split/test.txt'
names, datasets = prepare_paths(data_root, split_txt)

# Optional caps for quick runs; None disables the cap.
max_num_test_dataset = None  # number of datasets to evaluate
max_num_per_dataset = None  # number of subjects per dataset

zero_crop = False  # forwarded as `zero_crop` to utils.prepare_image

main_save_dir = make_dir('/autofs/space/yogurt_002/users/pl629/results/MTBrainID/test/', reset=False)

# (output-name postfix, checkpoint path) pairs to evaluate.
models = [
    ('test_sr', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr/l6_16/0926-2035/ckp/checkpoint_latest.pth'),
]

# (spacing, add_bf) test settings; a spacing of None is labelled '_1-1-1'.
setups = [
    ([1.5, 1.5, 5], False),
]
|
|
|
|
|
|
# Evaluate every (model, setup) pair on every subject of every dataset.
all_start_time = time.time()
for postfix, ckp_path in models:
    # Models whose postfix mentions 'hemis' use the hemisphere generator config
    # and a hemisphere mask.
    is_hemis = 'hemis' in postfix
    curr_gen_cfg = gen_hemis_cfg if is_hemis else gen_cfg

    for spacing, add_bf in setups:
        # The output folder name encodes model postfix, bias-field flag and spacing.
        curr_postfix = postfix + '_BF' if add_bf else postfix
        if spacing is not None:
            curr_postfix += '_%s-%s-%s' % (str(spacing[0]), str(spacing[1]), str(spacing[2]))
        else:
            curr_postfix += '_1-1-1'
        # NOTE(review): reset=True presumably clears earlier results for this setup -- confirm make_dir.
        save_dir = make_dir(os.path.join(main_save_dir, curr_postfix), reset=True)
        print('\nSave at: %s\n' % save_dir)

        for i, curr_dataset in enumerate(names):
            # Apply the dataset cap before doing any work or printing for this dataset.
            if max_num_test_dataset is not None and i >= max_num_test_dataset:
                break
            curr_dataset.sort()
            print('Dataset: %s (%d/%d) -- %d total cases' % (datasets[i], i + 1, len(datasets), len(curr_dataset)))

            start_time = time.time()
            for j, t1_name in enumerate(curr_dataset):
                if max_num_per_dataset is not None and j >= max_num_per_dataset:
                    break

                subj_name = os.path.basename(t1_name).split('.T1w')[0]
                subj_dir = make_dir(os.path.join(save_dir, subj_name))
                print('Now testing: %s (%d/%d)' % (t1_name, j + 1, len(curr_dataset)))

                modalities, aux = get_info(t1_name)

                # Cerebral label map at the test spacing, warped below into the deformed atlas.
                # NOTE(review): unlike the aux label maps saved below, this is prepared
                # without `is_label=True` -- confirm labels are not interpolated when resampled.
                S_cerebral = torch.squeeze(utils.prepare_image(aux['cerebral_label'], win_size=win_size, zero_crop=zero_crop, spacing=spacing, rescale=False, im_only=True, device=device))

                if is_hemis:
                    S = utils.prepare_image(aux['cerebral_label'], win_size=win_size, zero_crop=zero_crop, spacing=spacing, rescale=False, im_only=True, device=device)
                    # Remap raw label IDs through `lut`; IDs not listed map to 0.
                    # Index with int64, because older PyTorch rejects int32 index tensors.
                    S = lut[S.long()]
                    X = utils.prepare_image(aux['regx'], win_size=win_size, zero_crop=zero_crop, spacing=spacing, rescale=False, im_only=True, device=device)
                    # Keep voxels that have a listed non-zero label and a negative x
                    # coordinate (presumably the left hemisphere).
                    hemis_mask = ((S > 0) & (X < 0)).int()
                    viewVolume(hemis_mask, names=['.'.join(os.path.basename(aux['label']).split('.')[:2]) + '.hemis_mask'], save_dir=subj_dir)
                else:
                    hemis_mask = None

                # Save the original, degraded-input, high-res and (optionally)
                # bias-field volumes for every modality.
                for mod, mod_path in modalities.items():
                    stem = os.path.basename(mod_path)[:-4]  # drop the '.nii' extension
                    final, orig, high_res, bf, _, _, _ = utils.prepare_image(mod_path, win_size=win_size, zero_crop=zero_crop, spacing=spacing, add_bf=add_bf, is_CT='CT' in mod, rescale=False, hemis_mask=hemis_mask, im_only=False, device=device)
                    viewVolume(orig, names=[stem], save_dir=subj_dir)
                    viewVolume(final, names=[stem + '.input'], save_dir=subj_dir)
                    viewVolume(high_res, names=[stem + '.high_res'], save_dir=subj_dir)
                    if bf is not None:
                        viewVolume(bf, names=[stem + '.bias_field'], save_dir=subj_dir)

                # Save the auxiliary maps (no spacing applied); keys containing 'label' are label maps.
                for key, aux_path in aux.items():
                    im = utils.prepare_image(aux_path, win_size=win_size, zero_crop=zero_crop, is_label='label' in key, rescale=False, hemis_mask=hemis_mask, im_only=True, device=device)
                    viewVolume(im, names=[os.path.basename(aux_path)[:-4]], save_dir=subj_dir)

                # Run the model on each saved degraded input.
                for mod, mod_path in modalities.items():
                    stem = os.path.basename(mod_path)[:-4]
                    test_dir = make_dir(os.path.join(subj_dir, 'input_' + mod))
                    im = utils.prepare_image(os.path.join(subj_dir, stem + '.input.nii.gz'), win_size=win_size, zero_crop=zero_crop, is_CT='CT' in mod, hemis_mask=hemis_mask, im_only=True, device=device)
                    outs = utils.evaluate_image(im, ckp_path=ckp_path, feature_only=False, device=device, gen_cfg=curr_gen_cfg, model_cfg=model_cfg)

                    if mask_output:
                        # Binary mask: 1 wherever the input is non-zero, 0 elsewhere.
                        mask = im.clone()
                        mask[im != 0.] = 1.

                    for k, v in outs.items():
                        if 'feat' not in k and k not in exclude_keys:
                            viewVolume(v * mask if mask_output else v, names=['out_' + k], save_dir=test_dir)

                    # Warp the cerebral labels with the predicted registration fields.
                    # Skip this step, rather than raising KeyError, when the model
                    # did not output them.
                    reg_keys = ('regx', 'regy', 'regz')
                    missing = [k for k in reg_keys if k not in outs]
                    if not missing:
                        deformed_atlas = utils.get_deformed_atlas(S_cerebral, torch.squeeze(outs['regx']), torch.squeeze(outs['regy']), torch.squeeze(outs['regz']))
                        viewVolume(deformed_atlas * mask if mask_output else deformed_atlas, names=['out_deformed_atlas'], save_dir=test_dir)
                    else:
                        print('Skipping deformed atlas: model outputs lack %s' % ', '.join(missing))

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print('Testing time for {}: {}'.format(datasets[i], total_time_str))

all_total_time = time.time() - all_start_time
all_total_time_str = str(datetime.timedelta(seconds=int(all_total_time)))
print('Total testing time: {}'.format(all_total_time_str))
| |