| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.modules.utils import _single |
| import src.modules.utils as utils |
| from src.modules.multihead_attention import MultiheadAttention |
| import numpy as np |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| import copy |
|
|
|
|
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at ``padding_idx + 1``. Padding symbols are left
    untouched, but the caller must say whether padding sits on the left side
    (``left_pad=True``) or on the right side (``left_pad=False``).

    Args:
        tensor: 2-D integer tensor of symbols, indexed ``[row, column]``.
        padding_idx: value that marks a padding entry in ``tensor``.
        left_pad: whether each row is padded on the left.

    Returns:
        A new tensor shaped like ``tensor`` in which every non-padding entry
        holds its position number; the input tensor is not modified.
    """
    seq_len = tensor.size(1)
    first_pos = padding_idx + 1

    # Build the position range directly with the input's dtype and device.
    # The earlier code filled a freshly created empty buffer behind a
    # ``numel() < max_pos`` guard, which was a leftover from a removed cache.
    positions = torch.arange(
        first_pos, first_pos + seq_len, dtype=tensor.dtype, device=tensor.device,
    ).expand_as(tensor)

    mask = tensor.ne(padding_idx)
    if left_pad:
        # Shift each row so that its first real symbol gets ``padding_idx + 1``.
        positions = positions - seq_len + mask.long().sum(dim=1).unsqueeze(1)

    # Only non-padding slots are overwritten; padding keeps its original value.
    return tensor.clone().masked_scatter_(mask, positions[mask])
|
|
|
|
class LearnedPositionalEmbedding(nn.Embedding):
    """Embedding table whose rows are trainable position vectors.

    The table has a fixed number of rows, so only a bounded number of
    positions is supported (see ``max_positions``). Padding symbols are
    ignored, but the caller must state whether padding is on the left side
    (``left_pad=True``) or on the right side (``left_pad=False``).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        """Create the table and draw its weights from N(0, embedding_dim ** -0.5)."""
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Note: this re-initialisation also overwrites the padding row.
        nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
        self.left_pad = left_pad

    def forward(self, input, incremental_state=None):
        """Look up position embeddings for ``input`` of size [bsz x seqlen].

        Without ``incremental_state`` every non-padding symbol gets its own
        position via ``make_positions``. With ``incremental_state`` a single
        1 x 1 position, ``padding_idx + seqlen``, is looked up instead.
        """
        if incremental_state is None:
            positions = make_positions(input.data, self.padding_idx, self.left_pad)
        else:
            # Incremental decoding: only the newest step is needed.
            current_step = self.padding_idx + input.size(1)
            positions = input.data.new(1, 1).fill_(current_step)
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - (self.padding_idx + 1)
|
|
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    The table is not a parameter: it is kept in ``self.weights`` and rebuilt
    whenever a longer input arrives.
    """

    def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
        """Precompute a table with ``init_size`` rows of width ``embedding_dim``."""
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        # Empty buffer used only as a dtype/device reference for ``type_as``
        # in ``forward``.
        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: number of rows (positions) in the table.
            embedding_dim: width of each row.
            padding_idx: optional row index that is zeroed out.

        Returns:
            Float tensor of shape ``(num_embeddings, embedding_dim)`` whose
            first half of columns holds sines and second half cosines; an odd
            ``embedding_dim`` gets one trailing zero column.
        """
        half_dim = embedding_dim // 2
        # ``max(..., 1)`` avoids a ZeroDivisionError when embedding_dim is 2
        # or 3 (half_dim == 1); tensor2tensor clamps the same denominator.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
        # The concatenation is already (num_embeddings, 2 * half_dim), so no
        # reshape is needed (and none that would fail on an empty half).
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if embedding_dim % 2 == 1:
            # Zero-pad the last column so the width equals embedding_dim.
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen].

        Returns a ``(bsz, seqlen, embedding_dim)`` tensor, or
        ``(bsz, 1, embedding_dim)`` holding only the row for position
        ``padding_idx + seqlen`` when ``incremental_state`` is given.
        """
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # Recompute/expand the table when the input needs more positions.
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        # Follow the module's dtype/device, tracked through the buffer.
        self.weights = self.weights.type_as(self._float_tensor)

        if incremental_state is not None:
            # Incremental decoding: a single position shared by every row.
            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

        positions = make_positions(input.data, self.padding_idx, self.left_pad)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions (a fixed 100000)."""
        return int(1e5)
|
|
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    One layer chains three residual sub-blocks: self-attention with future
    timesteps masked, attention over the conditioning features (image and/or
    ingredient), and a two-layer feed-forward network of width ``embed_dim``.
    Each sub-block goes through ``maybe_layer_norm`` so the layer can
    normalise either before or after the sub-block.
    """

    def __init__(self, embed_dim, n_att, dropout=0.5, normalize_before=True, last_ln=False):
        """Build the attention, feed-forward and normalisation modules.

        Args:
            embed_dim: width of the layer input and output.
            n_att: second positional argument of both ``MultiheadAttention``
                modules (presumably the number of heads -- confirm against
                that class).
            dropout: dropout probability used for attention, the residual
                branches and the feed-forward activation.
            normalize_before: apply layer norm before each sub-block when
                true, after it otherwise.
            last_ln: add one extra layer norm on the layer output.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.dropout = dropout
        self.relu_dropout = dropout
        self.normalize_before = normalize_before
        num_layer_norm = 3

        # Self-attention over the decoded sequence.
        self.self_attn = MultiheadAttention(
            self.embed_dim, n_att,
            dropout=dropout,
        )

        # Attention over the conditioning features.
        self.cond_att = MultiheadAttention(
            self.embed_dim, n_att,
            dropout=dropout,
        )

        self.fc1 = Linear(self.embed_dim, self.embed_dim)
        self.fc2 = Linear(self.embed_dim, self.embed_dim)
        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(num_layer_norm)])
        self.use_last_ln = last_ln
        if self.use_last_ln:
            self.last_ln = LayerNorm(self.embed_dim)

    def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
        """Run the layer on ``x``.

        Args:
            x: decoder states; also used as query, key and value of the
                self-attention.
            ingr_features: ingredient features, or ``None``.
            ingr_mask: key padding mask for ``ingr_features``.
            incremental_state: passed through to both attention modules.
            img_features: image features, or ``None``.

        Returns:
            The transformed ``x``.
        """
        # Sub-block 1: self-attention.
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            need_weights=False,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)

        # Sub-block 2: attention over whichever conditioning inputs are given.
        residual = x
        x = self.maybe_layer_norm(1, x, before=True)

        if ingr_features is None:
            # Image only: no padding mask.
            memory, memory_mask = img_features, None
        elif img_features is None:
            # Ingredients only.
            memory, memory_mask = ingr_features, ingr_mask
        else:
            # Both: image features first, then ingredients, joined on dim 0.
            # Presumably features are (length, batch, channels) and masks
            # (batch, length) -- confirm against the caller.
            memory = torch.cat((img_features, ingr_features), 0)
            # All-zero mask for the image part. It is created on the device of
            # ``ingr_mask`` (not a module-level global) so the concatenation
            # cannot hit a device mismatch.
            img_mask = torch.zeros(
                img_features.shape[1], img_features.shape[0],
                dtype=torch.uint8, device=ingr_mask.device,
            )
            memory_mask = torch.cat((img_mask, ingr_mask), 1)

        x, _ = self.cond_att(
            query=x,
            key=memory,
            value=memory,
            key_padding_mask=memory_mask,
            incremental_state=incremental_state,
            static_kv=True,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)

        # Sub-block 3: feed-forward network (uses the last layer norm).
        residual = x
        x = self.maybe_layer_norm(-1, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(-1, x, after=True)

        if self.use_last_ln:
            x = self.last_ln(x)

        return x

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` to ``x`` if this call site is the active one.

        Exactly one of ``before`` / ``after`` must be true. The norm runs at
        the ``before`` site when ``normalize_before`` is set and at the
        ``after`` site otherwise; at the inactive site ``x`` is returned as is.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return self.layer_norms[i](x)
        return x
|
|
class DecoderTransformer(nn.Module):
    """Transformer decoder.

    Embeds a token sequence, runs it through a stack of
    ``TransformerDecoderLayer`` blocks conditioned on ingredient and/or image
    features, and projects to ``vocab_size - 1`` scores (the last vocabulary
    id is the padding index and is never predicted). ``sample`` and
    ``sample_beam`` decode step by step.
    """

    def __init__(self, embed_size, vocab_size, dropout=0.5, seq_length=20, num_instrs=15,
                 attention_nheads=16, pos_embeddings=True, num_layers=8, learned=True, normalize_before=True,
                 normalize_inputs=False, last_ln=False, scale_embed_grad=False):
        """Build embeddings, decoder layers and the output projection.

        Args:
            embed_size: width of token embeddings and decoder states.
            vocab_size: vocabulary size; ``vocab_size - 1`` is the padding id.
            dropout: dropout probability for embeddings and every layer.
            seq_length, num_instrs: their product is the number of decoding
                steps used by ``sample`` / ``sample_beam``.
            attention_nheads: forwarded to every decoder layer.
            pos_embeddings: add positional embeddings when true.
            num_layers: number of decoder layers.
            learned: learned (true) or sinusoidal (false) positions.
            normalize_before: forwarded to every decoder layer.
            normalize_inputs: create input layer norms (see ``forward``).
            last_ln: forwarded to every decoder layer.
            scale_embed_grad: ``scale_grad_by_freq`` of the token embedding.
        """
        super(DecoderTransformer, self).__init__()
        self.dropout = dropout
        self.seq_length = seq_length * num_instrs
        self.embed_tokens = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size - 1,
                                         scale_grad_by_freq=scale_embed_grad)
        nn.init.normal_(self.embed_tokens.weight, mean=0, std=embed_size ** -0.5)
        if pos_embeddings:
            self.embed_positions = PositionalEmbedding(1024, embed_size, 0, left_pad=False, learned=learned)
        else:
            self.embed_positions = None
        self.normalize_inputs = normalize_inputs
        if self.normalize_inputs:
            self.layer_norms_in = nn.ModuleList([LayerNorm(embed_size) for i in range(3)])

        self.embed_scale = math.sqrt(embed_size)
        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(embed_size, attention_nheads, dropout=dropout, normalize_before=normalize_before,
                                    last_ln=last_ln)
            for i in range(num_layers)
        ])

        self.linear = Linear(embed_size, vocab_size - 1)

    def forward(self, ingr_features, ingr_mask, captions, img_features, incremental_state=None):
        """Score the next token for every position of ``captions``.

        Args:
            ingr_features: 3-D ingredient features or ``None``; rearranged
                with ``permute(0, 2, 1)`` then ``transpose(0, 1)``
                (presumably from (batch, channels, length) -- confirm).
            ingr_mask: ingredient mask or ``None``; it is inverted with
                ``1 - mask`` before being handed to the layers.
            captions: token ids of size [bsz x seqlen].
            img_features: 3-D image features or ``None``; rearranged like
                ``ingr_features``.
            incremental_state: when not ``None`` only the last token of
                ``captions`` is processed and the dict is passed to every
                layer.

        Returns:
            ``(scores, predicted)``: the output of the final linear layer and
            its argmax over the last dimension.
        """
        if ingr_features is not None:
            ingr_features = ingr_features.permute(0, 2, 1)
            ingr_features = ingr_features.transpose(0, 1)
            if self.normalize_inputs:
                # NOTE(review): the result is discarded, so this norm has no
                # effect. Left unchanged on purpose: assigning it would alter
                # the outputs of models already trained with this code.
                self.layer_norms_in[0](ingr_features)

        if img_features is not None:
            img_features = img_features.permute(0, 2, 1)
            img_features = img_features.transpose(0, 1)
            if self.normalize_inputs:
                # NOTE(review): result discarded, same as above.
                self.layer_norms_in[1](img_features)

        if ingr_mask is not None:
            ingr_mask = (1 - ingr_mask.squeeze(1)).byte()

        # Positional embeddings are computed on the full prefix ...
        if self.embed_positions is not None:
            positions = self.embed_positions(captions, incremental_state=incremental_state)
        # ... but in incremental mode only the newest step is decoded.
        if incremental_state is not None:
            if self.embed_positions is not None:
                positions = positions[:, -1:]
            captions = captions[:, -1:]

        # Embed tokens (scaled by sqrt(embed_size)) and add positions.
        x = self.embed_scale * self.embed_tokens(captions)

        if self.embed_positions is not None:
            x += positions

        if self.normalize_inputs:
            x = self.layer_norms_in[2](x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # The layers work with the first two dimensions swapped.
        x = x.transpose(0, 1)

        for layer in self.layers:
            x = layer(
                x,
                ingr_features,
                ingr_mask,
                incremental_state,
                img_features
            )

        # Swap the first two dimensions back.
        x = x.transpose(0, 1)

        x = self.linear(x)
        _, predicted = x.max(dim=-1)

        return x, predicted

    @staticmethod
    def _ban_tokens(mask, token_ids):
        """Write -inf into ``mask[row, token_ids[row]]`` for rows whose id is not 0."""
        rows = torch.arange(token_ids.size(0), device=token_ids.device)[token_ids != 0]
        mask[rows, token_ids[rows]] = float('-inf')

    def sample(self, ingr_features, ingr_mask, greedy=True, temperature=1.0, beam=-1,
               img_features=None, first_token_value=0,
               replacement=True, last_token_value=0):
        """Decode ``self.seq_length`` tokens step by step.

        Args:
            ingr_features, ingr_mask, img_features: conditioning inputs, as
                in ``forward``; at least one feature tensor is required.
            greedy: take the arg-max token when true, otherwise sample from
                the (at most) 10 most likely tokens.
            temperature: divisor applied to the scores before the softmax in
                the non-greedy case.
            beam: beam width; ``-1`` disables beam search. Beam search is
                only used when the batch size is 1.
            first_token_value: id fed as the first input token.
            replacement: when false, a token id other than 0 that has been
                emitted is masked out for the rest of that sequence.
            last_token_value: end-of-sequence id; only used by beam search.

        Returns:
            ``(sampled_ids, logits)``: ids stacked along dimension 1 and the
            per-step scores stacked along dimension 1 (including any -inf
            masking applied for ``replacement=False``). With beam search the
            return value of ``sample_beam`` is passed through instead.
        """
        incremental_state = {}

        # The batch size and the device come from whichever feature is given.
        cond = ingr_features if ingr_features is not None else img_features
        fs = cond.size(0)

        if beam != -1:
            if fs == 1:
                return self.sample_beam(ingr_features, ingr_mask, beam, img_features, first_token_value,
                                        replacement, last_token_value)
            else:
                print ("Beam Search can only be used with batch size of 1. Running greedy or temperature sampling...")

        # Created next to the conditioning tensor instead of on a
        # module-level global device.
        first_word = (torch.ones(fs, device=cond.device) * first_token_value).long()
        sampled_ids = [first_word]
        logits = []

        for i in range(self.seq_length):
            # Forward pass; with incremental state only the last token is used.
            outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sampled_ids, 1),
                                      img_features, incremental_state)
            outputs = outputs.squeeze(1)
            if not replacement:
                if i == 0:
                    predicted_mask = torch.zeros_like(outputs)
                else:
                    # Forbid the token emitted at the previous step.
                    self._ban_tokens(predicted_mask, sampled_ids[i])
                outputs += predicted_mask

            logits.append(outputs)
            if greedy:
                outputs_prob = F.softmax(outputs, dim=-1)
                _, predicted = outputs_prob.max(1)
                predicted = predicted.detach()
            else:
                # Never ask for more candidates than there are classes.
                k = min(10, outputs.size(-1))
                outputs_prob = torch.div(outputs, temperature)
                outputs_prob = F.softmax(outputs_prob, dim=-1).data

                # Sample a rank within each row's top-k, then map it back to
                # a token id. ``gather`` keeps the choice per row; the former
                # ``index_select(...)[:, 0]`` applied row 0's rank to every row.
                prob_prev_topk, indices = torch.topk(outputs_prob, k=k, dim=1)
                choice = torch.multinomial(prob_prev_topk, 1)
                predicted = indices.gather(1, choice).squeeze(1).detach()

            sampled_ids.append(predicted)

        sampled_ids = torch.stack(sampled_ids[1:], 1)
        logits = torch.stack(logits, 1)

        return sampled_ids, logits

    def sample_beam(self, ingr_features, ingr_mask, beam=3, img_features=None, first_token_value=0,
                    replacement=True, last_token_value=0):
        """Beam-search decoding for a single example.

        Every hypothesis is a list
        ``[tokens, score, incremental_state, finished, length]`` where
        ``score`` is the summed log-probability. A hypothesis is finished
        once it emits ``last_token_value``.

        Returns:
            ``(sampled_ids, score)``: the tokens of the best finished
            hypothesis (or the best unfinished one if none finished), stacked
            along dimension 1 without the start token, and its score as a
            Python float.
        """
        k = beam
        alpha = 0.0  # length-normalisation exponent; 0.0 leaves scores as is

        def ranking_key(hyp):
            return hyp[1] / np.power(hyp[-1], alpha)

        cond = ingr_features if ingr_features is not None else img_features
        fs = cond.size(0)
        first_word = (torch.ones(fs, device=cond.device) * first_token_value).long()

        sequences = [[[first_word], 0, {}, False, 1]]
        finished = []

        for _step in range(self.seq_length):
            all_candidates = []
            for tokens_so_far, score_so_far, incremental, _done, length in sequences:
                outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(tokens_so_far, 1),
                                          img_features, incremental)
                outputs = outputs.squeeze(1)
                if not replacement:
                    # Rebuild the mask from this hypothesis' own history. A
                    # single shared mask would leak banned tokens between
                    # hypotheses. The start token is not banned.
                    predicted_mask = torch.zeros_like(outputs)
                    for past in tokens_so_far[1:]:
                        self._ban_tokens(predicted_mask, past)
                    outputs += predicted_mask

                outputs_prob = F.log_softmax(outputs, dim=-1)
                # Never ask for more expansions than there are classes.
                width = min(beam, outputs_prob.size(-1))
                probs, indices = torch.topk(outputs_prob, width)

                for bid in range(width):
                    tokens = tokens_so_far + [indices[:, bid]]
                    score = score_so_far + probs[:, bid].squeeze().item()
                    if indices[:, bid].item() == last_token_value:
                        finished.append([tokens, score, None, True, length + 1])
                    else:
                        # NOTE(review): sibling hypotheses share one
                        # incremental-state dict. If the attention modules
                        # append to it on every call, siblings see each
                        # other's cached steps -- confirm against
                        # MultiheadAttention.
                        all_candidates.append([tokens, score, incremental, False, length + 1])

            # Stop expanding once the k best hypotheses overall are finished.
            # The flag lives at index 3; index -1 is the length counter, so
            # testing it (as before) never triggered.
            ordered_all = sorted(all_candidates + finished, key=ranking_key, reverse=True)[:k]
            if all(el[3] for el in ordered_all):
                all_candidates = []

            # Keep the k best open and the k best finished hypotheses.
            sequences = sorted(all_candidates, key=ranking_key, reverse=True)[:k]
            finished = sorted(finished, key=ranking_key, reverse=True)[:k]
            if not sequences:
                break

        if len(finished) != 0:
            sampled_ids = torch.stack(finished[0][0][1:], 1)
            logits = finished[0][1]
        else:
            sampled_ids = torch.stack(sequences[0][0][1:], 1)
            logits = sequences[0][1]
        return sampled_ids, logits

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Adapt ``state_dict`` entries of a sinusoidal positional embedding.

        Drops a stored ``decoder.embed_positions.weights`` entry and adds an
        empty ``decoder.embed_positions._float_tensor`` if it is missing.
        Other kinds of positional embedding leave the dict untouched.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder.embed_positions.weights' in state_dict:
                del state_dict['decoder.embed_positions.weights']
            if 'decoder.embed_positions._float_tensor' not in state_dict:
                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
        return state_dict
|
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with weights drawn from N(0, embedding_dim ** -0.5).

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: width of each row.
        padding_idx: index of the padding row, or ``None`` for no padding row.

    Returns:
        The initialised ``nn.Embedding``; its padding row, if any, is zero.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        # normal_ overwrites the zero row nn.Embedding sets up for
        # padding_idx; restore it, as PositionalEmbedding does.
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
|
|
|
|
def LayerNorm(embedding_dim):
    """Return a new ``nn.LayerNorm`` built with ``embedding_dim`` as its shape."""
    return nn.LayerNorm(embedding_dim)
|
|
|
|
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has a bias term.

    Returns:
        The initialised ``nn.Linear``.
    """
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        # With bias=False there is no bias tensor to initialise.
        nn.init.constant_(m.bias, 0.)
    return m
|
|
|
|
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    """Build a positional-embedding module.

    With ``learned=False`` a ``SinusoidalPositionalEmbedding`` is returned,
    using ``num_embeddings`` as its initial table size. With ``learned=True``
    a ``LearnedPositionalEmbedding`` is returned whose weights are drawn from
    N(0, embedding_dim ** -0.5) and whose padding row is set to zero.
    """
    if not learned:
        return SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)

    table = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
    nn.init.normal_(table.weight, mean=0, std=embedding_dim ** -0.5)
    # The padding position carries no information: keep its vector at zero.
    nn.init.constant_(table.weight[padding_idx], 0)
    return table
|
|