| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| from collections import defaultdict, OrderedDict |
| import logging |
| import os |
| import re |
| import torch |
| import traceback |
|
|
| from torch.serialization import default_restore_location |
|
|
|
|
def torch_persistent_save(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)``, retrying up to three times.

    Exceptions raised by ``torch.save`` are swallowed. Only the traceback of
    the third (final) failed attempt is reported, via ``logging.error``.

    Returns:
        The value returned by the first successful ``torch.save`` call, or
        ``None`` if every attempt raised.
    """
    attempts = 3
    for attempt in range(1, attempts + 1):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            # Best-effort save: stay quiet on early failures, log the last one.
            if attempt == attempts:
                logging.error(traceback.format_exc())
|
|
|
|
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    """Recursively convert every tensor in a (nested) state dict to ``ttype``.

    Args:
        state_dict: a tensor, or a dict / list (arbitrarily nested) whose
            leaves may be tensors. Non-tensor leaves are returned unchanged.
        ttype: tensor type passed to ``Tensor.type`` for every tensor leaf,
            at any nesting depth.

    Returns:
        A structure of the same shape. Dicts become ``OrderedDict`` (keys in
        the original iteration order), lists stay lists, and tensors are
        converted with ``tensor.type(ttype)``.
    """
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            # Forward ttype so nested tensors get the requested type rather
            # than silently falling back to the FloatTensor default.
            cpu_dict[k] = convert_state_dict_type(v, ttype)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
|
|
|
|
def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    """Write a training checkpoint to ``filename`` via ``torch_persistent_save``.

    The saved dict has these keys:
    - ``'args'``.
    - ``'model'``: the model state dict passed through ``convert_state_dict_type``.
    - ``'optimizer_history'``: ``optim_history`` plus one new entry recording
      the criterion and optimizer class names, the LR-scheduler state and
      ``num_updates``.
    - ``'last_optimizer_state'``: the converted optimizer state dict.
    - ``'extra_state'``.

    ``optim_history`` itself is not mutated. ``None`` for ``optim_history`` /
    ``extra_state`` is treated as an empty list / empty dict.
    """
    history = [] if optim_history is None else optim_history
    extra = {} if extra_state is None else extra_state

    model_state = convert_state_dict_type(model.state_dict())
    latest_entry = {
        'criterion_name': criterion.__class__.__name__,
        'optimizer_name': optimizer.__class__.__name__,
        'lr_scheduler_state': lr_scheduler.state_dict(),
        'num_updates': num_updates,
    }
    optimizer_state = convert_state_dict_type(optimizer.state_dict())

    checkpoint = {
        'args': args,
        'model': model_state,
        'optimizer_history': history + [latest_entry],
        'last_optimizer_state': optimizer_state,
        'extra_state': extra,
    }
    torch_persistent_save(checkpoint, filename)
|
|
|
|
def load_model_state(filename, model):
    """Load model weights from the checkpoint at ``filename`` into ``model``.

    The checkpoint is loaded with every storage mapped to CPU. It is then
    upgraded by ``_upgrade_state_dict`` and ``model.upgrade_state_dict`` and
    loaded with ``strict=True``.

    Args:
        filename: path to the checkpoint file.
        model: object providing ``upgrade_state_dict`` and ``load_state_dict``.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from
        the checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if ``model.load_state_dict`` fails. The original error is
            chained as ``__cause__`` so the mismatch details are not lost.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(
        filename,
        map_location=lambda storage, location: default_restore_location(storage, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
|
|
|
|
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Rewrites an older checkpoint dict in place, step by step, into the layout
    written by ``save_state``, then returns the same dict. Every step is
    guarded by a presence check, so fields that are already in the newer
    layout are left untouched.

    Args:
        state: checkpoint dict as returned by ``torch.load``.

    Returns:
        The mutated ``state`` dict.

    Raises:
        KeyError: if a legacy field that a step reads from is missing.
    """
    # Checkpoints without 'optimizer_history' keep 'best_loss' and 'optimizer'
    # at the top level; move them into a single-entry history.
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # Move top-level epoch / batch_offset / val_loss into 'extra_state'.
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # Optimizer state stored inside history entries: keep only the most recent
    # one, as 'last_optimizer_state', and strip it from every entry.
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        # NOTE(review): raises KeyError if an earlier entry lacks 'optimizer';
        # confirm legacy checkpoints always carry it on every entry.
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # Entries without an optimizer name are assumed to be 'FairseqNAG'.
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # Replace the entry's 'best_loss' with an LR-scheduler state {'best': ...}.
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # Entries without an update counter start from 0.
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # Split a single 'max_positions' arg into source/target variants.
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # Add iterator position info, restarting at the beginning of the epoch.
    # NOTE(review): assumes 'extra_state' exists and contains 'epoch'; a
    # checkpoint with neither 'epoch' nor 'extra_state' raises KeyError here.
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': 0,
        }
    return state
|
|
|
|
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    All checkpoints are read (storages mapped to CPU) and upgraded before any
    model is built. Every model is built from the ``args`` stored in the
    *first* checkpoint. Those args can be patched with
    ``model_arg_overrides``, a dict ``{'arg_name': value}`` that overrides
    args used during training.

    Returns:
        ``(ensemble, args)``: one model per entry of ``filenames``, plus the
        args used to build them.

    Raises:
        IOError: if any path in ``filenames`` does not exist.
    """
    def _to_cpu(storage, location):
        return default_restore_location(storage, 'cpu')

    loaded = []
    for path in filenames:
        if not os.path.exists(path):
            raise IOError('Model file not found: {}'.format(path))
        checkpoint = torch.load(path, map_location=_to_cpu)
        loaded.append(_upgrade_state_dict(checkpoint))

    model_args = loaded[0]['args']
    if model_arg_overrides is not None:
        model_args = _override_model_args(model_args, model_arg_overrides)

    models = []
    for checkpoint in loaded:
        net = task.build_model(model_args)
        net.upgrade_state_dict(checkpoint['model'])
        net.load_state_dict(checkpoint['model'], strict=True)
        models.append(net)
    return models, model_args
|
|
|
|
| def _override_model_args(args, model_arg_overrides): |
| |
| for arg_name, arg_val in model_arg_overrides.items(): |
| setattr(args, arg_name, arg_val) |
| return args |
|
|
|
|
def move_to_cuda(sample):
    """Recursively move every tensor in ``sample`` to CUDA via ``.cuda()``.

    Dicts are rebuilt as plain ``dict`` and lists as ``list``, with their
    contents converted. Any other value, including other containers such as
    tuples, is returned as-is. An empty ``sample`` (``len(sample) == 0``)
    short-circuits to ``{}``.
    """
    if len(sample) == 0:
        return {}

    def _convert(obj):
        if isinstance(obj, dict):
            return {key: _convert(val) for key, val in obj.items()}
        if isinstance(obj, list):
            return [_convert(elem) for elem in obj]
        if torch.is_tensor(obj):
            return obj.cuda()
        return obj

    return _convert(sample)
|
|
|
|
# Per-class-name counters used to hand out ``_fairseq_instance_id`` values;
# a class name that has not been seen yet reads as 0.
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(int)
|
|
|
|
def _get_full_incremental_state_key(module_instance, key):
    """Namespace ``key`` by the module's class name and a per-instance id.

    The first time an instance is seen, it is tagged with
    ``_fairseq_instance_id``, the next value of a per-class counter (1, 2, ...).
    Distinct instances of the same class therefore produce distinct keys of
    the form ``'<ClassName>.<id>.<key>'``.
    """
    cls_name = module_instance.__class__.__name__
    if not hasattr(module_instance, '_fairseq_instance_id'):
        next_id = INCREMENTAL_STATE_INSTANCE_ID[cls_name] + 1
        INCREMENTAL_STATE_INSTANCE_ID[cls_name] = next_id
        module_instance._fairseq_instance_id = next_id
    instance_id = module_instance._fairseq_instance_id
    return '{}.{}.{}'.format(cls_name, instance_id, key)
|
|
|
|
def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns the value stored under the namespaced ``module``/``key`` entry,
    or ``None`` when ``incremental_state`` is ``None`` or has no such entry.
    The namespaced key is computed unconditionally, so ``module`` receives its
    instance id even when nothing is found.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is not None and full_key in incremental_state:
        return incremental_state[full_key]
    return None
|
|
|
|
def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module.

    Stores ``value`` under the namespaced ``module``/``key`` entry. It is a
    no-op, and the key is never computed, when ``incremental_state`` is
    ``None``.
    """
    if incremental_state is None:
        return
    incremental_state[_get_full_incremental_state_key(module, key)] = value
|
|
|
|
def load_align_dict(replace_unk):
    """Build the alignment dictionary used for unknown-word replacement.

    Args:
        replace_unk: one of:
            - ``None``: disable replacement.
            - a ``str`` path to a file whose lines are whitespace-split; the
              first two columns map a source token to its replacement.
            - any other value: enable replacement with an empty dictionary,
              so unknown words are replaced by copying the aligned source
              token.

    Returns:
        ``None``, or a ``dict`` mapping source token -> replacement token.
        Later lines override earlier ones for the same source token.
    """
    if replace_unk is None:
        return None
    mapping = {}
    if isinstance(replace_unk, str):
        with open(replace_unk, 'r') as handle:
            mapping.update((cols[0], cols[1]) for cols in map(str.split, handle))
    return mapping
|
|
|
|
def print_embed_overlap(embed_dict, vocab_dict):
    """Print how many distinct embedding keys are also vocabulary symbols.

    The printed denominator is ``len(vocab_dict)``.
    """
    vocab_symbols = set(vocab_dict.symbols)
    num_found = sum(1 for word in set(embed_dict.keys()) if word in vocab_symbols)
    print("| Found {}/{} types in embedding file.".format(num_found, len(vocab_dict)))
|
|
|
|
def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line is always skipped; it is expected to hold the vocabulary
    size and dimension. Each following line holds a word and its embedding
    values, separated by single spaces.
    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932

    Returns:
        ``dict`` mapping word -> ``torch.Tensor`` of its float values. If a
        word repeats, the last occurrence wins.
    """
    with open(embed_path) as handle:
        next(handle)  # header line: vocabulary size and dimension
        rows = (line.rstrip().split(" ") for line in handle)
        return {
            row[0]: torch.Tensor([float(value) for value in row[1:]])
            for row in rows
        }
|
|
|
|
def load_embedding(embed_dict, vocab, embedding):
    """Copy pretrained vectors from ``embed_dict`` into ``embedding``.

    For each index ``i`` in ``range(len(vocab))`` whose token ``vocab[i]`` is
    a key of ``embed_dict``, row ``i`` of ``embedding.weight.data`` is
    overwritten in place. Rows of tokens missing from ``embed_dict`` are left
    untouched. ``vocab`` is accessed by index on purpose rather than iterated.

    Returns:
        The same ``embedding`` object.
    """
    weights = embedding.weight.data
    num_tokens = len(vocab)
    idx = 0
    while idx < num_tokens:
        token = vocab[idx]
        if token in embed_dict:
            weights[idx] = embed_dict[token]
        idx += 1
    return embedding
|
|
|
|
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    """Replace each ``unk`` token in ``hypo_str`` using the source alignment.

    For the i-th hypothesis token equal to ``unk``, take the aligned source
    token ``src_tokens[alignment[i]]``. The tokenized source has ``'<eos>'``
    appended. That token is looked up in ``align_dict``; if it is absent,
    the source token itself is copied.

    Returns:
        The hypothesis tokens re-joined with single spaces.
    """
    from fairseq import tokenizer

    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # '<eos>' is appended, presumably so alignments pointing at the
    # end-of-sentence position still index a valid entry.
    src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']

    def _lookup(position):
        source_token = src_tokens[alignment[position]]
        return align_dict.get(source_token, source_token)

    return ' '.join(
        _lookup(i) if token == unk else token
        for i, token in enumerate(hypo_tokens)
    )
|
|
|
|
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
    """Turn hypothesis token ids into a string and optionally post-process it.

    The ids are rendered with ``tgt_dict.string(hypo_tokens, remove_bpe)``.
    When ``align_dict`` is not ``None``, unknown words are then replaced via
    ``replace_unk``. When either post-processing step applies, the string is
    re-tokenized with ``add_if_not_exist=True``, which presumably adds unseen
    words to ``tgt_dict``.

    Returns:
        ``(hypo_tokens, hypo_str, alignment)``. ``alignment`` is passed
        through unchanged.
    """
    from fairseq import tokenizer

    hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
    replace_unks = align_dict is not None
    if replace_unks:
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
    if replace_unks or remove_bpe is not None:
        # The string may no longer match the original ids, so re-encode it.
        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment
|
|
|
|
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored, but
    it is necessary to specify whether padding is added on the left side
    (left_pad=True) or right side (left_pad=False).

    Args:
        tensor: 2-D tensor whose ``size(1)`` is the sequence length; rows are
            processed independently.
        padding_idx: value that marks padding entries.
        left_pad: ``True`` if each row's padding sits on the left.

    Returns:
        A new tensor in which non-padding entries hold their position number
        and padding entries keep ``padding_idx``.
    """
    seq_len = tensor.size(1)
    start = padding_idx + 1
    # A range buffer is cached on the function to avoid reallocating it on
    # every call. The start value it was filled from is recorded as well: a
    # buffer built for a different padding_idx holds the wrong numbers even
    # if it is long enough.
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    cached_start = getattr(make_positions, 'range_buf_start', None)
    if make_positions.range_buf.numel() < seq_len or cached_start != start:
        torch.arange(start, start + seq_len, out=make_positions.range_buf)
        make_positions.range_buf_start = start
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:seq_len].expand_as(tensor)
    if left_pad:
        # Shift each row so its first non-padding token is numbered `start`.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
|
|
|
|
def strip_pad(tensor, pad):
    """Return a 1-D tensor of the entries of ``tensor`` that are not ``pad``."""
    keep = tensor != pad
    return tensor[keep]
|
|
|
|
def buffered_arange(max):
    """Return ``[0, 1, ..., max - 1]`` as a slice of a cached ``LongTensor``.

    The buffer is stored on the function and regrown only when ``max``
    exceeds its current length. The returned tensor is a view of that
    buffer and shares its storage.
    """
    if not hasattr(buffered_arange, 'buf'):
        buffered_arange.buf = torch.LongTensor()
    cache = buffered_arange.buf
    if cache.numel() < max:
        torch.arange(max, out=cache)
    return cache[:max]
|
|
|
|
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Rotate each row of ``src_tokens`` so its padding moves to the other side.

    Exactly one of ``right_to_left`` / ``left_to_right`` is expected to be
    ``True`` (checked with ``assert``). With ``right_to_left``, padding at the
    end of each row is moved to the front; ``left_to_right`` does the reverse.
    The input is returned unchanged when it contains no padding, or when the
    column padding would be taken from has no padding in any row.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        return src_tokens  # nothing padded at all
    if left_to_right and not pad_mask[:, 0].any():
        return src_tokens  # already right-padded
    if right_to_left and not pad_mask[:, -1].any():
        return src_tokens  # already left-padded
    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotate every row by its own pad count: backwards for right_to_left,
    # forwards otherwise.
    shift = -num_pads if right_to_left else num_pads
    index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, index)
|
|
|
|
def item(tensor):
    """Extract a Python scalar from ``tensor`` where possible.

    Uses ``tensor.item()`` if that attribute exists, otherwise ``tensor[0]``
    for indexable objects, and otherwise returns ``tensor`` unchanged.
    """
    if hasattr(tensor, 'item'):
        return tensor.item()
    return tensor[0] if hasattr(tensor, '__getitem__') else tensor
|
|
|
|
def clip_grad_norm_(tensor, max_norm):
    """Rescale ``tensor`` in place when its norm exceeds ``max_norm``.

    The norm comes from ``torch.norm(tensor)``. Rescaling happens only if
    ``max_norm > 0`` and the norm is larger than ``max_norm``. The scale
    factor is ``max_norm / (norm + 1e-6)``.

    Returns:
        The norm measured *before* any rescaling, as a Python number.
    """
    grad_norm = item(torch.norm(tensor))
    needs_clip = max_norm > 0 and grad_norm > max_norm
    if needs_clip:
        tensor.mul_(max_norm / (grad_norm + 1e-6))
    return grad_norm
|
|
|
|
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is done on ``t.float()``, and the result is cast back with
    ``type_as(t)``. For non-float32 inputs this yields a new tensor. For
    float32 inputs ``t.float()`` is ``t`` itself, so ``t`` is filled in place
    and returned.
    """
    neg_inf = float('-inf')
    as_float = t.float()
    as_float.fill_(neg_inf)
    return as_float.type_as(t)
|
|
|
|
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    A file counts as a checkpoint if its whole name matches ``pattern``
    (``re.fullmatch``). Results are sorted in descending order of
    ``(key, filename)``. The key is ``int`` of the pattern's first group when
    the pattern has groups; otherwise it is the file's position in
    ``os.listdir`` order.

    Returns:
        List of full paths (``os.path.join(path, filename)``).
    """
    regex = re.compile(pattern)
    matches = []
    for position, filename in enumerate(os.listdir(path)):
        match = regex.fullmatch(filename)
        if match is None:
            continue
        sort_key = int(match.group(1)) if match.groups() else position
        matches.append((sort_key, match.group(0)))
    matches.sort(reverse=True)
    return [os.path.join(path, name) for _, name in matches]
|
|