|
|
| import datetime |
| import os, sys |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| import glob |
| import yaml |
| import json |
| import random |
| import time |
| from argparse import Namespace |
| from pathlib import Path |
|
|
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from torch.utils.data import DataLoader |
|
|
| from utils.checkpoint import load_checkpoint |
| import utils.logging as logging |
| import utils.misc as utils |
| |
| from Generator import build_datasets |
| from Trainer.visualizer import TaskVisualizer, FeatVisualizer |
| from Trainer.models import build_model, build_optimizer, build_schedulers |
| from Trainer.engine import train_one_epoch |
| |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Config file locations, merged at launch time (see __main__ below):
# the defaults are loaded first, then overridden by the experiment-specific
# yaml given on the command line.
# NOTE(review): these are hard-coded absolute paths into one user's space;
# consider moving them to environment variables or CLI arguments so the
# script is portable across machines/users.
# ---------------------------------------------------------------------------

# Cluster/submission settings (GPUs, workers, distributed mode).
submit_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/submit.yaml'

# Default synthetic-data generator settings.
default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml'

# Default trainer settings (train + validation).
default_train_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_train.yaml'
default_val_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/default_val.yaml'

# Directories searched for the experiment-specific override yamls.
gen_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/train'
train_cfg_dir = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/train'
|
|
|
|
def get_params_groups(model):
    """Collect all trainable parameters of *model* into one optimizer group.

    Args:
        model: a ``torch.nn.Module`` whose ``named_parameters()`` are scanned.

    Returns:
        A single-element list ``[{'params': [...]}]`` containing every
        parameter with ``requires_grad=True``, in ``named_parameters()``
        order. Frozen parameters are skipped.
    """
    # Comprehension instead of the original append loop; the original also
    # shadowed the builtin ``all`` with its accumulator name.
    trainable = [param for _, param in model.named_parameters() if param.requires_grad]
    return [{'params': trainable}]
|
|
|
|
def train(args):
    """Run the full training loop for one launched job.

    Args:
        args: 3-tuple ``(submit_args, gen_args, train_args)`` of config
            ``Namespace``s as produced by ``utils.preprocess_cfg`` and passed
            through ``utils.launch_job``.

    Side effects:
        Creates ``out_dir`` plus ``cfg``/``plt``/``vis-train``/``ckp``
        subdirectories, dumps all three configs to yaml, configures logging,
        saves a checkpoint at the start of every epoch, appends per-epoch
        stats to ``log.txt``, and plots per-loss curves.
    """

    submit_args, gen_args, train_args = args

    # --- distributed / device setup ------------------------------------
    utils.init_distributed_mode(submit_args)
    if torch.cuda.is_available():
        # Clamp the requested GPU count to what is actually present.
        if submit_args.num_gpus > torch.cuda.device_count():
            submit_args.num_gpus = torch.cuda.device_count()
        assert (
            submit_args.num_gpus <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        submit_args.num_gpus = 0

    # Single-process data loading makes debugging (breakpoints) possible.
    if train_args.debug:
        submit_args.num_workers = 0

    # --- output directories --------------------------------------------
    # NOTE(review): ``cfg_dir / '...'`` and ``ckp_output_dir / '...'`` below
    # use Path division, while the resume branch uses string concatenation
    # (``ckp_output_dir + '/*.pth'``) — utils.make_dir presumably returns a
    # pathlib.Path, in which case the ``+`` would raise TypeError. Verify
    # make_dir's return type and unify the usage.
    output_dir = utils.make_dir(train_args.out_dir)
    cfg_dir = utils.make_dir(os.path.join(output_dir, "cfg"))
    plt_dir = utils.make_dir(os.path.join(output_dir, "plt"))
    vis_train_dir = utils.make_dir(os.path.join(output_dir, "vis-train"))
    ckp_output_dir = utils.make_dir(os.path.join(output_dir, "ckp"))

    # Snapshot the effective configs next to the run outputs.
    # NOTE(review): the open() handles are never closed explicitly; CPython's
    # refcounting closes them, but a ``with`` block would be safer.
    yaml.dump(
        vars(submit_args),
        open(cfg_dir / 'config_submit.yaml', 'w'), allow_unicode=True)
    yaml.dump(
        vars(gen_args),
        open(cfg_dir / 'config_generator.yaml', 'w'), allow_unicode=True)
    yaml.dump(
        vars(train_args),
        open(cfg_dir / 'config_trainer.yaml', 'w'), allow_unicode=True)

    # --- logging --------------------------------------------------------
    logging.setup_logging(output_dir)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(submit_args)).items())))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(gen_args)).items())))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(train_args)).items())))
    log_path = os.path.join(output_dir, 'log.txt')

    # Explicit device override wins; otherwise current CUDA device or CPU.
    if submit_args.device is not None:
        device = submit_args.device
    elif torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = 'cpu'
    logger.info('device: %s' % device)

    # --- seeding --------------------------------------------------------
    # NOTE(review): the seed comes from wall-clock time, so in a multi-process
    # distributed run each rank may get a different seed, and runs are not
    # reproducible despite cudnn.deterministic — confirm this is intended.
    seed = int(time.time())
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    # --- datasets / loaders ---------------------------------------------
    dataset_dict = build_datasets(gen_args, device = gen_args.device_generator if gen_args.device_generator is not None else device)
    data_loader_dict = {}
    data_total = 0
    for name in dataset_dict.keys():
        if submit_args.num_gpus>1:
            sampler_train = utils.DistributedWeightedSampler(dataset_dict[name])
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_dict[name])

        # NOTE(review): ``sampler_train`` is overwritten on every loop
        # iteration; only the LAST dataset's sampler survives, and it is the
        # only one that gets set_epoch() below. If per-epoch reshuffling is
        # needed per dataset, keep a dict of samplers instead.
        data_loader_dict[name] = DataLoader(
            dataset_dict[name],
            batch_sampler=torch.utils.data.BatchSampler(sampler_train, train_args.batch_size, drop_last=True),
            num_workers=submit_args.num_workers)
        data_total += len(data_loader_dict[name])
        logger.info('Dataset: {}'.format(name))
    logger.info('Num of total training data: {}'.format(data_total))

    # --- visualizers -----------------------------------------------------
    visualizers = {'result': TaskVisualizer(gen_args, train_args)}
    if train_args.visualizer.feat_vis:
        visualizers['feature'] = FeatVisualizer(gen_args, train_args)

    # --- model ----------------------------------------------------------
    # build_model may mutate/return updated configs alongside the model parts.
    gen_args, train_args, model, processors, criterion, postprocessor = build_model(gen_args, train_args, device = device)

    # Keep a handle to the bare module; DDP wraps it when multi-GPU.
    model_without_ddp = model
    if submit_args.num_gpus > 1:
        # NOTE(review): 'currect' is a typo for 'current' in this log line
        # (left untouched here — runtime strings are not changed in a
        # documentation-only pass).
        logger.info('currect device: %s' % str(torch.cuda.current_device()))
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[device], output_device=device,
            find_unused_parameters=True
        )
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('Num of trainable model params: {}'.format(n_parameters))

    # --- optimizer / schedulers ------------------------------------------
    # Mixed-precision gradient scaler for use inside train_one_epoch.
    scaler = torch.cuda.amp.GradScaler()
    param_dicts = get_params_groups(model_without_ddp)
    optimizer = build_optimizer(train_args, param_dicts)

    lr_scheduler, wd_scheduler = build_schedulers(train_args, data_total, train_args.lr, train_args.min_lr)
    logger.info(f"Optimizer and schedulers ready.")

    # --- resume ----------------------------------------------------------
    best_val_stats = None
    train_args.start_epoch = 0
    if train_args.resume or train_args.eval_only:
        if train_args.ckp_path:
            ckp_path = train_args.ckp_path
        else:
            # NOTE(review): sorted(glob.glob(...)) is a LIST of every .pth in
            # the directory, not a single path; if load_checkpoint expects one
            # file this should likely be sorted(...)[-1] (latest). Also mixes
            # string '+' with the Path-division used elsewhere — verify.
            ckp_path = sorted(glob.glob(ckp_output_dir + '/*.pth'))

        train_args.start_epoch, best_val_stats = load_checkpoint(ckp_path, [model_without_ddp], optimizer, ['model'], exclude_key = 'supervised_seg')
        logger.info(f"Resume epoch: {train_args.start_epoch}")
    else:
        logger.info('Starting from scratch')
    # reset_epoch forces epoch numbering back to 0 even after a resume.
    if train_args.reset_epoch:
        train_args.start_epoch = 0
    logger.info(f"Start epoch: {train_args.start_epoch}")

    # --- training loop ----------------------------------------------------
    logger.info("Start training")
    start_time = time.time()

    for epoch in range(train_args.start_epoch, train_args.n_epochs):

        # Rotate the previous "latest" checkpoint to a backup before
        # overwriting, so an interrupted save cannot destroy both copies.
        if os.path.isfile(os.path.join(ckp_output_dir,'checkpoint_latest.pth')):
            os.rename(os.path.join(ckp_output_dir,'checkpoint_latest.pth'), os.path.join(ckp_output_dir,'checkpoint_latest_bk.pth'))

        checkpoint_paths = [ckp_output_dir / 'checkpoint_latest.pth']

        # Checkpoint is written at the START of the epoch (state from the
        # previous epoch); only the master process actually writes.
        for checkpoint_path in checkpoint_paths:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'submit_args': submit_args,
                'gen_args': gen_args,
                'train_args': train_args,
                'best_val_stats': best_val_stats
            }, checkpoint_path)

        if submit_args.num_gpus > 1:
            # NOTE(review): only the last dataset's sampler (see loader loop
            # above) is re-seeded per epoch — confirm this is intended.
            sampler_train.set_epoch(epoch)
        log_stats = train_one_epoch(epoch, gen_args, train_args, model_without_ddp, processors, criterion, data_loader_dict,
            scaler, optimizer, lr_scheduler, wd_scheduler, postprocessor, visualizers, vis_train_dir, device)

        # Append this epoch's stats as one JSON line (master process only).
        if utils.is_main_process():
            with (Path(output_dir) / "log.txt").open("a") as f:
                f.write('epoch %s - ' % str(epoch).zfill(5))
                f.write(json.dumps(log_stats) + "\n")

        # Re-plot per-loss curves and their sum from the accumulated log.
        if os.path.isfile(log_path):
            # NOTE(review): sum_losses is sized (epoch + 1) but indexed by
            # range(len(curr_losses)); if the log contains more entries than
            # epoch + 1 (e.g. after a resume with reset_epoch) this raises
            # IndexError — verify read_log's contract.
            sum_losses = [0.] * (epoch + 1)
            for loss_name in criterion.loss_names:
                curr_epoches, curr_losses = utils.read_log(log_path, 'loss_' + loss_name)
                sum_losses = [sum_losses[i] + curr_losses[i] for i in range(len(curr_losses))]
                utils.plot_loss(curr_losses, os.path.join(utils.make_dir(plt_dir), 'loss_%s.png' % loss_name))
            utils.plot_loss(sum_losses, os.path.join(utils.make_dir(plt_dir), 'loss_all.png'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
|
|
|
|
| |
|
|
if __name__ == '__main__':
    # Build the three config namespaces by layering defaults with the
    # experiment-specific yamls given on the command line:
    #   argv[1] -> generator override (resolved against gen_cfg_dir)
    #   argv[2] -> trainer override (resolved against train_cfg_dir)
    # NOTE(review): sys.argv is indexed directly with no usage message; running
    # without both arguments raises IndexError — consider argparse.
    submit_args = utils.preprocess_cfg([submit_cfg_file])
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, sys.argv[1]], cfg_dir = gen_cfg_dir)
    train_args = utils.preprocess_cfg([default_train_cfg_file, default_val_file, sys.argv[2]], cfg_dir = train_cfg_dir)
    # launch_job handles single- vs multi-process spawning, then calls train().
    utils.launch_job(submit_args, gen_args, train_args, train)