Instructions to use Cainiao-AI/TAAS with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Cainiao-AI/TAAS with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Cainiao-AI/TAAS", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Cainiao-AI/TAAS", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from copy import deepcopy | |
| from torch.nn.init import xavier_uniform_ | |
| import torch.nn.functional as F | |
| from torch.nn import Parameter | |
| from torch.nn.init import normal_ | |
| import torch.utils.checkpoint | |
| from torch import Tensor, device | |
| from .TAAS_utils import * | |
| from transformers.modeling_utils import ModuleUtilsMixin | |
| from fairseq import utils | |
| from fairseq.models import ( | |
| FairseqEncoder, | |
| FairseqEncoderModel, | |
| register_model, | |
| register_model_architecture, | |
| ) | |
| from fairseq.modules import ( | |
| LayerNorm, | |
| ) | |
| from fairseq.utils import safe_hasattr | |
def init_params(module, n_layers):
    """Re-initialise a linear or embedding module in place.

    * ``nn.Linear``: weights are drawn from a zero-mean normal distribution
      with standard deviation ``0.02 / sqrt(n_layers)``; the bias, when the
      layer has one, is set to zero.
    * ``nn.Embedding``: weights are drawn from a zero-mean normal
      distribution with standard deviation ``0.02``.

    Modules of any other type are left untouched.  Nothing is returned.
    """
    if isinstance(module, nn.Linear):
        # Deeper stacks get proportionally smaller initial linear weights.
        linear_std = 0.02 / math.sqrt(n_layers)
        module.weight.data.normal_(mean=0.0, std=linear_std)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
def softmax_dropout(input, dropout_prob: float, is_training: bool):
    """Apply softmax over the last dimension, then dropout.

    Args:
        input: tensor of unnormalised scores.
        dropout_prob: probability of zeroing an element.
        is_training: dropout is only active when this is true.

    Returns:
        The softmax of ``input`` along its last dimension, passed through
        ``F.dropout``.
    """
    probabilities = F.softmax(input, -1)
    return F.dropout(probabilities, p=dropout_prob, training=is_training)
class SelfMultiheadAttention(nn.Module):
    """Multi-head self-attention with an optional additive attention bias.

    Queries, keys and values are all projected from the same input.  The
    queries are multiplied by ``(head_dim * scaling_factor) ** -0.5`` before
    the dot product.

    Args:
        embed_dim: width of the input and output features; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention probabilities.
        bias: whether the output projection has a bias term.  The q/k/v
            projections always use the ``nn.Linear`` default (bias enabled).
        scaling_factor: extra factor folded into the query scaling.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        scaling_factor=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5

        self.linear_q = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
        self.linear_k = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
        self.linear_v = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor = None,
    ) -> Tensor:
        """Run self-attention over ``query``.

        Args:
            query: tensor of shape ``(n_graph, n_node, embed_dim)``.
            attn_bias: optional tensor added to the raw attention scores; it
                must broadcast against
                ``(n_graph, num_heads, n_node, n_node)``.  When ``None`` no
                bias is applied.

        Returns:
            Tensor of shape ``(n_graph, n_node, embed_dim)``.
        """
        n_graph, _, embed_dim = query.size()

        def split_heads(projected: Tensor) -> Tensor:
            # (n_graph, n_node, embed_dim) -> (n_graph, num_heads, n_node, head_dim)
            return projected.contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.linear_q(query)) * self.scaling
        k = split_heads(self.linear_k(query))
        v = split_heads(self.linear_v(query))

        # (n_graph, num_heads, n_node, n_node)
        attn_weights = torch.matmul(q, k.transpose(2, 3))
        # The signature advertises ``attn_bias=None``; honour it instead of
        # failing on ``Tensor + None``.
        if attn_bias is not None:
            attn_weights = attn_weights + attn_bias

        # Softmax over the key dimension followed by dropout on the probabilities.
        attn_probs = F.dropout(F.softmax(attn_weights, -1), self.dropout, self.training)

        attn = torch.matmul(attn_probs, v)
        # Merge the heads back into a single feature dimension.
        attn = attn.transpose(1, 2).contiguous().view(n_graph, -1, embed_dim)
        return self.out_proj(attn)
class Graphormer3DEncoderLayer(nn.Module):
    """
    One Graphormer-3D encoder layer.

    The layer applies layer normalisation *before* each sub-block
    (self-attention, then a two-layer GELU feed-forward network) and adds
    the sub-block output back onto its un-normalised input.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Hyper-parameters kept on the instance.
        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Attention sub-block and the layer norm applied in front of it.
        self.self_attn = SelfMultiheadAttention(
            embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(embedding_dim)

        # Feed-forward sub-block and the layer norm applied in front of it.
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)
        self.final_layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor, attn_bias: Tensor = None):
        """Apply the attention and feed-forward sub-blocks to ``x``.

        ``attn_bias`` is forwarded unchanged to the self-attention module.
        The result is whatever the two residual additions produce, so it has
        the same shape as ``x``.
        """
        # --- self-attention sub-block -------------------------------------
        attended = self.self_attn(query=self.self_attn_layer_norm(x), attn_bias=attn_bias)
        attended = F.dropout(attended, p=self.dropout, training=self.training)
        x = x + attended

        # --- feed-forward sub-block ---------------------------------------
        hidden = F.gelu(self.fc1(self.final_layer_norm(x)))
        hidden = F.dropout(hidden, p=self.activation_dropout, training=self.training)
        hidden = F.dropout(self.fc2(hidden), p=self.dropout, training=self.training)
        return x + hidden
| from fairseq.models import ( | |
| BaseFairseqModel, | |
| register_model, | |
| register_model_architecture, | |
| ) | |
class Graphormer3D(BaseFairseqModel):
    """Graph encoder built from a stack of ``Graphormer3DEncoderLayer`` blocks.

    All hyper-parameters are hard-coded in ``__init__``.  ``forward`` builds a
    per-head attention bias from shortest-path ("spatial") indices and from
    edge indices along those paths, adds degree and position embeddings to
    the node features, and runs the encoder stack.

    Several submodules created in ``__init__`` (``atom_encoder``,
    ``engergy_proj``, ``energe_agg_factor``, ``graph_token``,
    ``graph_token_virtual_distance``, ``gbf``, ``bias_proj``, ``edge_proj``,
    ``node_proc``) are never referenced by ``forward`` in this file.
    NOTE(review): they are presumably kept so parameter names line up with
    existing checkpoints -- confirm before removing any of them.
    """

    def __init__(self):
        super().__init__()
        self.atom_types = 64
        self.edge_types = 64 * 64
        self.embed_dim = 768
        self.layer_nums = 12
        self.ffn_embed_dim = 768
        self.blocks = 4
        self.attention_heads = 48
        # NOTE(review): this float is replaced by an ``nn.Dropout(0.1)``
        # module a few lines below, so the 0.0 here is never used.
        self.input_dropout = 0.0
        self.dropout = 0.1
        self.attention_dropout = 0.1
        self.activation_dropout = 0.0
        self.node_loss_weight = 15
        self.min_node_loss_weight = 1
        self.eng_loss_weight = 1
        self.num_kernel = 128
        # NOTE(review): ``atom_encoder`` is assigned again further down with a
        # different vocabulary size; this first instance is discarded.
        self.atom_encoder = nn.Embedding(self.atom_types, self.embed_dim, padding_idx=0)
        # Maps edge ids (< 32) to one value per attention head; used by
        # ``forward`` for both ``edge_type_matrix`` and ``edge_input``.
        self.edge_embedding = nn.Embedding(32, self.attention_heads, padding_idx=0)
        self.input_dropout = nn.Dropout(0.1)
        self.layers = nn.ModuleList(
            [
                Graphormer3DEncoderLayer(
                    self.embed_dim,
                    self.ffn_embed_dim,
                    num_attention_heads=self.attention_heads,
                    dropout=self.dropout,
                    attention_dropout=self.attention_dropout,
                    activation_dropout=self.activation_dropout,
                )
                for _ in range(self.layer_nums)
            ]
        )
        self.atom_encoder = nn.Embedding(512 * 9 + 1, self.embed_dim, padding_idx=0)
        # Only referenced by the non-'multi_hop' branch of ``forward``.
        self.edge_encoder = nn.Embedding(512 * 3 + 1, self.attention_heads, padding_idx=0)
        self.edge_type = 'multi_hop'
        if self.edge_type == 'multi_hop':
            # Flat storage that ``forward`` reshapes to
            # (16, attention_heads, attention_heads): one head-mixing matrix
            # per hop, so at most 16 hops are supported.
            self.edge_dis_encoder = nn.Embedding(16 * self.attention_heads * self.attention_heads, 1)
        # Shortest-path index -> per-head bias (index 0 is padding).
        self.spatial_pos_encoder = nn.Embedding(512, self.attention_heads, padding_idx=0)
        # Degree and position indices -> node-feature-sized embeddings.
        self.in_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
        self.out_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
        self.node_position_ids_encoder = nn.Embedding(10, self.embed_dim, padding_idx=0)
        self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.embed_dim)
        self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(self.embed_dim, 1)
        self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
        nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)
        self.graph_token = nn.Embedding(1, 768)
        self.graph_token_virtual_distance = nn.Embedding(1, self.attention_heads)
        K = self.num_kernel
        self.gbf: Callable[[Tensor, Tensor], Tensor] = GaussianLayer(K, self.edge_types)
        self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(K, self.attention_heads)
        self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.embed_dim)
        self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(self.embed_dim, self.attention_heads)

    def forward(self, node_feature, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids):
        """Encode a batch of graphs into per-node representations.

        The shapes below are inferred from how each argument is indexed in
        this method; none of them is validated.

        :param node_feature: node features whose first two dimensions are
            ``(n_graph, n_node)``; it is summed with ``embed_dim``-wide
            embeddings, so presumably ``(n_graph, n_node, embed_dim)``.
        :param spatial_pos: integer indices into ``spatial_pos_encoder``,
            presumably shortest-path lengths between node pairs with shape
            ``(n_graph, n_node, n_node)``; 0 is treated as padding.
        :param in_degree: integer indices into ``in_degree_encoder``,
            presumably ``(n_graph, n_node)``.
        :param out_degree: integer indices into ``out_degree_encoder``,
            presumably ``(n_graph, n_node)``.
        :param edge_type_matrix: integer indices into ``edge_embedding``; the
            result is only consumed by the non-'multi_hop' branch.
        :param edge_input: integer indices into ``edge_embedding``; expected
            to be 4-D ``(n_graph, n_node, n_node, max_dist)`` with
            ``max_dist <= 16``.
        :param node_position_ids: integer indices into
            ``node_position_ids_encoder``; the embedded result is viewed as
            ``(n_graph, n_node, embed_dim)``.
        :return: output of the last encoder layer after ``final_ln``.
        """
        # attn_bias, spatial_pos, x = batch_data.attn_bias, batch_data.spatial_pos, batch_data.x
        # in_degree, out_degree = batch_data.in_degree, batch_data.out_degree
        # edge_input, attn_edge_type = batch_data.edge_input, batch_data.attn_edge_type
        # graph_attn_bias
        attn_edge_type = self.edge_embedding(edge_type_matrix)
        edge_input = self.edge_embedding(edge_input)#.mean(-2)
        # Original note: add a virtual node representing the whole graph and
        # then treat it like a regular node.  The virtual-node code below is
        # commented out, so no such node is added in this version.
        n_graph, n_node = node_feature.size()[:2]
        # graph_attn_bias = attn_bias.clone()
        # graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(1, self.attention_heads, 1, 1) # [n_graph, n_head, n_node+1, n_node+1]

        # spatial pos
        # Spatial encoding: a learnable per-head scalar for each shortest-path
        # length between two nodes.
        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
        spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
        # graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
        # graph_attn_bias = spatial_pos_bias
        # reset spatial pos here
        # Original note: every node is directly connected to the virtual node,
        # so the shortest-path length between any node and it is 1.
        # t = self.graph_token_virtual_distance.weight.view(1, self.attention_heads, 1)
        # graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
        # graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t

        # edge feature
        # For each node pair, average over the shortest path the product of
        # the edge features with a learnable per-hop matrix, and use the
        # result as an additive bias in the attention module.
        if self.edge_type == 'multi_hop':
            spatial_pos_ = spatial_pos.clone()
            spatial_pos_[spatial_pos_ == 0] = 1  # set pad to 1
            # set 1 to 1, x > 1 to x - 1
            spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
            # if self.multi_hop_max_dist > 0:
            #     spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
            #     edge_input = edge_input[:, :, :, :self.multi_hop_max_dist, :]
            # [n_graph, n_node, n_node, max_dist, n_head]
            # edge_input = self.edge_encoder(edge_input).mean(-2)
            max_dist = edge_input.size(-2)
            # Bring the hop axis to the front so each hop is one bmm batch entry.
            edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.attention_heads)
            # Multiply hop i by the i-th (n_head x n_head) matrix of edge_dis_encoder.
            edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(-1, self.attention_heads, self.attention_heads)[:max_dist, :, :])
            edge_input = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.attention_heads).permute(1, 2, 3, 0, 4)
            # Sum over hops and divide by the adjusted path length computed
            # above (always >= 1 when spatial_pos is non-negative), then move
            # the head axis to position 1.
            edge_input = (edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
        else:
            # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
            # NOTE(review): attn_edge_type is already the (floating point)
            # output of edge_embedding, but nn.Embedding expects integer
            # indices.  This branch is unreachable while edge_type is
            # hard-coded to 'multi_hop' -- confirm intent before enabling it.
            edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
        # graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + edge_input
        graph_attn_bias = spatial_pos_bias + edge_input
        # graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
        # graph_attn_bias = graph_attn_bias.contiguous().view(-1, 6, 6)

        # node feature + graph token
        # node_feature = x # self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden]
        # if self.flag and perturb is not None:
        #     node_feature += perturb
        node_position_embedding = self.node_position_ids_encoder(node_position_ids)
        node_position_embedding = node_position_embedding.contiguous().view(n_graph, n_node, self.embed_dim)
        # print(node_position_embedding.shape)
        # Each node gets two embedding vectors selected by its in-degree and
        # out-degree (plus a position embedding); all are added to the node
        # features to form the encoder input.
        node_feature = node_feature + self.in_degree_encoder(in_degree) + \
            self.out_degree_encoder(out_degree) + node_position_embedding
        # print(node_feature.shape)
        # graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
        # graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)

        # transformer encoder
        output = self.input_dropout(node_feature)#.permute(1, 0, 2)
        # The same attention bias is passed to every layer.
        for enc_layer in self.layers:
            output = enc_layer(output, graph_attn_bias)
        output = self.final_ln(output)

        # output part
        # Original note: the whole-graph representation is the virtual node's
        # feature at the last layer.  That pooling is commented out, so all
        # per-node outputs are returned instead.
        # if self.dataset_name == 'PCQM4M-LSC':
        #     # get whole graph rep
        #     output = self.out_proj(output[:, 0, :])
        # else:
        #     output = self.downstream_out_proj(output[:, 0, :])
        # print(output.shape)
        return output
def gaussian(x, mean, std):
    """Evaluate a Gaussian probability density element-wise.

    Computes ``exp(-0.5 * ((x - mean) / std) ** 2) / (sqrt(2 * pi) * std)``
    with ``pi`` truncated to ``3.14159``.  The truncated constant is kept
    as-is so the numerical output of this function does not change.
    """
    pi = 3.14159
    sqrt_two_pi = (2 * pi) ** 0.5
    standardized = (x - mean) / std
    unnormalized = torch.exp(-0.5 * (standardized ** 2))
    return unnormalized / (sqrt_two_pi * std)
class GaussianLayer(nn.Module):
    """Expand scalar inputs into ``K`` Gaussian basis responses.

    Each input value is first rescaled by a learnable multiplier and offset
    looked up by edge type, then evaluated under ``K`` Gaussians with
    learnable means and standard deviations.
    """

    def __init__(self, K=128, edge_types=1024):
        super().__init__()
        self.K = K

        # One learnable mean / std per kernel, stored as single-row embeddings.
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        # One learnable scale / offset per edge type.
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)

        # Kernel parameters start uniformly in [0, 3].
        for kernel_table in (self.means, self.stds):
            nn.init.uniform_(kernel_table.weight, 0, 3)
        # The per-edge-type transform starts as the identity (x * 1 + 0).
        nn.init.constant_(self.mul.weight, 1)
        nn.init.constant_(self.bias.weight, 0)

    def forward(self, x, edge_types):
        """Return the Gaussian basis expansion of ``x``.

        ``x`` gains a trailing axis which is expanded to ``K``; the 4-argument
        ``expand`` call means ``x`` is expected to be 3-D.  ``edge_types``
        holds integer ids used to look up the scale and offset.  The result
        is cast back to the dtype of the kernel parameters.
        """
        scale = self.mul(edge_types)
        shift = self.bias(edge_types)
        transformed = scale * x.unsqueeze(-1) + shift
        transformed = transformed.expand(-1, -1, -1, self.K)

        centers = self.means.weight.float().view(-1)
        # abs() + epsilon keeps the standard deviation strictly positive.
        widths = self.stds.weight.float().view(-1).abs() + 1e-5

        responses = gaussian(transformed.float(), centers, widths)
        return responses.type_as(self.means.weight)
class RBF(nn.Module):
    """Radial-basis expansion with ``K`` learnable centres and temperatures.

    Inputs are first rescaled by a learnable multiplier and offset looked up
    by edge type, then mapped to ``exp(-|temp| * (x - mean) ** 2)`` for each
    of the ``K`` kernels.
    """

    def __init__(self, K, edge_types):
        super().__init__()
        self.K = K

        # Kernel centres and temperatures.
        self.means = nn.parameter.Parameter(torch.empty(K))
        self.temps = nn.parameter.Parameter(torch.empty(K))
        # One learnable scale / offset per edge type.
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)

        nn.init.uniform_(self.means, 0, 3)
        nn.init.uniform_(self.temps, 0.1, 10)
        # The per-edge-type transform starts as the identity (x * 1 + 0).
        nn.init.constant_(self.mul.weight, 1)
        nn.init.constant_(self.bias.weight, 0)

    def forward(self, x: Tensor, edge_types):
        """Return the RBF expansion of ``x``.

        ``x`` gains a trailing axis that broadcasts against the ``K`` kernel
        centres.  ``edge_types`` holds integer ids used to look up the scale
        and offset.  The result is cast to the dtype of ``means``.
        """
        scale = self.mul(edge_types)
        shift = self.bias(edge_types)
        transformed = scale * x.unsqueeze(-1) + shift

        centers = self.means.float()
        # abs() keeps the exponent non-positive whatever sign was learned.
        sharpness = self.temps.float().abs()

        squared_distance = (transformed - centers).square()
        activations = (squared_distance * (-sharpness)).exp()
        return activations.type_as(self.means)
class NonLinear(nn.Module):
    """Two-layer perceptron with a GELU activation in between.

    Args:
        input: number of input features.
        output_size: number of output features.
        hidden: width of the hidden layer; defaults to ``input`` when
            ``None``.
    """

    def __init__(self, input, output_size, hidden=None):
        super().__init__()
        hidden_width = input if hidden is None else hidden
        self.layer1 = nn.Linear(input, hidden_width)
        self.layer2 = nn.Linear(hidden_width, output_size)

    def forward(self, x):
        """Return ``layer2(gelu(layer1(x)))``."""
        activated = F.gelu(self.layer1(x))
        return self.layer2(activated)
class NodeTaskHead(nn.Module):
    """Attention head that turns node features into a 3-component vector per node.

    Attention probabilities between node pairs are multiplied by the three
    components of ``delta_pos``, used to aggregate the value vectors, and
    each of the three aggregated results is projected to a scalar by its own
    linear layer.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.scaling = (embed_dim // num_heads) ** -0.5
        # One scalar projection per output component.
        self.force_proj1 = nn.Linear(embed_dim, 1)
        self.force_proj2 = nn.Linear(embed_dim, 1)
        self.force_proj3 = nn.Linear(embed_dim, 1)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor,
        delta_pos: Tensor,
    ) -> Tensor:
        """Compute the per-node 3-vector.

        Args:
            query: tensor of shape ``(bsz, n_node, embed_dim)``.
            attn_bias: added to the attention scores after they are viewed as
                ``(bsz * num_heads, n_node, n_node)``; must broadcast against
                that shape.
            delta_pos: pairwise 3-component tensor; it is broadcast against
                ``(bsz, num_heads, n_node, n_node, 3)``, so presumably
                ``(bsz, n_node, n_node, 3)``.

        Returns:
            Float tensor of shape ``(bsz, n_node, 3)``.
        """
        bsz, n_node, _ = query.size()
        heads = self.num_heads

        def to_heads(projected: Tensor) -> Tensor:
            # (bsz, n_node, embed_dim) -> (bsz, heads, n_node, head_dim)
            return projected.view(bsz, n_node, heads, -1).transpose(1, 2)

        q = to_heads(self.q_proj(query)) * self.scaling
        k = to_heads(self.k_proj(query))
        v = to_heads(self.v_proj(query))

        scores = q @ k.transpose(-1, -2)  # [bsz, head, n, n]
        biased_scores = scores.view(-1, n_node, n_node) + attn_bias
        # Softmax over keys followed by dropout with a fixed probability of 0.1.
        probs = F.dropout(F.softmax(biased_scores, -1), 0.1, self.training)
        probs = probs.view(bsz, heads, n_node, n_node)

        # Weight every attention probability by each component of delta_pos.
        directional = probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(probs)  # [bsz, head, n, n, 3]
        directional = directional.permute(0, 1, 4, 2, 3)
        aggregated = directional @ v.unsqueeze(2)  # [bsz, head, 3, n, d]
        # Merge heads back into one feature axis: [bsz, n, 3, head * d]
        aggregated = aggregated.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)

        projections = (self.force_proj1, self.force_proj2, self.force_proj3)
        components = [
            proj(aggregated[:, :, axis, :]).view(bsz, n_node, 1)
            for axis, proj in enumerate(projections)
        ]
        return torch.cat(components, dim=-1).float()