| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
class Graph():
    """Skeleton graph that produces the adjacency tensor used by the network.

    The layout is fixed: 22 nodes joined by the bone list in ``get_edge``
    (plus one self-link per node), with node 0 used as the centre.

    Args:
        strategy (string): partitioning strategy, one of
            - uniform: Uniform Labeling
            - distance: Distance Partitioning
            - spatial: Spatial Configuration
        max_hop (int): the maximal distance between two connected nodes
        dilation (int): controls the spacing between the kernel points

    Attributes:
        num_node (int): number of nodes in the graph.
        edge (list): ``(i, j)`` index pairs, self-links included.
        center (int): index of the centre node.
        hop_dis (np.ndarray): ``(num_node, num_node)`` hop distances as
            returned by ``get_hop_distance``.
        A (np.ndarray): ``(K, num_node, num_node)`` adjacency tensor, where
            ``K`` depends on the strategy (see ``get_adjacency``).

    Raises:
        ValueError: if ``strategy`` is not one of the names listed above.
    """

    def __init__(self,
                 strategy='spatial',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        self.get_edge()
        self.hop_dis = get_hop_distance(self.num_node,
                                        self.edge,
                                        max_hop=max_hop)
        self.get_adjacency(strategy)

    def __str__(self):
        # __str__ must return a ``str``; returning the array itself makes
        # ``str(graph)`` / ``print(graph)`` raise TypeError.
        return str(self.A)

    def get_edge(self):
        """Define the node count, the edge list and the centre node."""
        self.num_node = 22
        self_link = [(i, i) for i in range(self.num_node)]
        neighbor_link = [(1, 0), (2, 1), (3, 2), (4, 3), (5, 0), (6, 5),
                         (7, 6), (8, 7), (9, 0), (10, 9), (11, 10),
                         (12, 11), (13, 12), (14, 11), (15, 14), (16, 15),
                         (17, 16), (18, 11), (19, 18), (20, 19), (21, 20)]
        self.edge = self_link + neighbor_link
        self.center = 0

    def get_adjacency(self, strategy):
        """Build ``self.A`` for the requested partitioning strategy.

        - ``uniform``: a single normalised adjacency matrix, ``K == 1``.
        - ``distance``: one matrix per valid hop, ``K == len(valid_hop)``.
        - ``spatial``: one matrix for hop 0 and two for every other valid
          hop, split by comparing each node's hop distance to the centre.

        Raises:
            ValueError: if ``strategy`` is not a known strategy name.
        """
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                at_hop = self.hop_dis == hop
                A[i][at_hop] = normalize_adjacency[at_hop]
            self.A = A
        elif strategy == 'spatial':
            # Hop distance of every node to the centre, broadcast so that
            # entry [j, i] compares node j (row) with node i (column).
            # Distances beyond max_hop are np.inf, and inf == inf is True,
            # so two out-of-range nodes count as "same distance".
            center_dis = self.hop_dis[:, self.center]
            dis_j = center_dis[:, np.newaxis]
            dis_i = center_dis[np.newaxis, :]
            same = dis_j == dis_i
            row_farther = dis_j > dis_i

            subsets = []
            for hop in valid_hop:
                at_hop = self.hop_dis == hop
                a_root = np.where(at_hop & same, normalize_adjacency, 0.0)
                a_close = np.where(at_hop & row_farther,
                                   normalize_adjacency, 0.0)
                a_further = np.where(at_hop & ~same & ~row_farther,
                                     normalize_adjacency, 0.0)
                if hop == 0:
                    subsets.append(a_root)
                else:
                    subsets.append(a_root + a_close)
                    subsets.append(a_further)
            self.A = np.stack(subsets)
        else:
            raise ValueError("Do Not Exist This Strategy")
|
|
def get_hop_distance(num_node, edge, max_hop=1):
    """Return the matrix of shortest hop counts between nodes, capped at ``max_hop``.

    Args:
        num_node (int): number of nodes in the graph.
        edge (iterable): ``(i, j)`` index pairs; each pair is treated as
            undirected.
        max_hop (int): largest hop count that is searched for.

    Returns:
        np.ndarray: ``(num_node, num_node)`` float array. Entry ``[a, b]`` is
        the smallest ``d`` in ``0..max_hop`` for which the ``d``-th power of
        the adjacency matrix is positive at ``[a, b]``, or ``np.inf`` when no
        such ``d`` exists.
    """
    adjacency = np.zeros((num_node, num_node))
    for src, dst in edge:
        adjacency[src, dst] = 1
        adjacency[dst, src] = 1

    # walks[d][a, b] > 0 exactly when b can be reached from a in d steps.
    walks = np.stack([
        np.linalg.matrix_power(adjacency, steps)
        for steps in range(max_hop + 1)
    ])

    hop_dis = np.full((num_node, num_node), np.inf)
    for steps, counts in enumerate(walks):
        # Fill in ascending order and only touch entries that are still
        # unset, so each pair keeps its smallest hop count.
        unset = np.isinf(hop_dis)
        hop_dis[unset & (counts > 0)] = steps
    return hop_dis
|
|
def normalize_digraph(A):
    """Scale each column of ``A`` by the inverse of its column sum.

    Columns whose sum is not positive are multiplied by zero.

    Args:
        A (np.ndarray): square adjacency matrix.

    Returns:
        np.ndarray: ``A`` right-multiplied by the diagonal matrix of inverse
        column sums.
    """
    column_sums = np.sum(A, 0)
    size = A.shape[0]
    inverse_degree = [
        column_sums[k] ** (-1) if column_sums[k] > 0 else 0.0
        for k in range(size)
    ]
    scale = np.diag(np.array(inverse_degree, dtype=np.float64))
    return np.dot(A, scale)
|
|
def normalize_undigraph(A):
    """Scale ``A`` on both sides by the inverse square root of its column sums.

    Columns whose sum is not positive get a zero scaling factor.

    Args:
        A (np.ndarray): square adjacency matrix.

    Returns:
        np.ndarray: ``D @ A @ D`` where ``D`` is the diagonal matrix holding
        ``column_sum ** -0.5`` for every column with a positive sum.
    """
    column_sums = np.sum(A, 0)
    size = A.shape[0]
    inverse_sqrt_degree = [
        column_sums[k] ** (-0.5) if column_sums[k] > 0 else 0.0
        for k in range(size)
    ]
    scale = np.diag(np.array(inverse_sqrt_degree, dtype=np.float64))
    left_scaled = np.dot(scale, A)
    return np.dot(left_scaled, scale)
|
|
def zero(x):
    """Ignore ``x`` and return the integer ``0``."""
    return 0
|
|
def iden(x):
    """Return ``x`` itself, unchanged (identity function)."""
    return x
|
|
class ConvTemporalGraphical(nn.Module):
    r"""Graph convolution over a sequence of graphs.

    A 2-D convolution with kernel ``(t_kernel_size, 1)`` produces
    ``out_channels * kernel_size`` feature maps. They are regrouped into
    ``kernel_size`` banks of ``out_channels`` maps, and bank ``k`` is mixed
    over the node axis with ``A[k]``; the banks are then summed.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel, i.e. the
            number of adjacency matrices ``K`` expected in ``forward``
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: The adjacency matrix that was passed in, returned unchanged

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        # The convolution only spans the time axis; the node axis has a
        # kernel extent of 1 and is mixed later through ``A``.
        temporal_options = dict(
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )
        self.conv = nn.Conv2d(in_channels,
                              out_channels * kernel_size,
                              **temporal_options)

    def forward(self, x, A):
        """Apply the convolution to ``x`` and aggregate over nodes with ``A``.

        Returns:
            tuple: the contiguous output tensor and ``A`` itself.
        """
        assert A.size(0) == self.kernel_size

        features = self.conv(x)

        batch, stacked_channels, frames, nodes = features.size()
        channels_per_subset = stacked_channels // self.kernel_size
        features = features.view(batch, self.kernel_size,
                                 channels_per_subset, frames, nodes)
        # Sum over the subset axis k and the source-node axis v.
        aggregated = torch.einsum('nkctv,kvw->nctw', features, A)

        return aggregated.contiguous(), A
|
|
class st_gcn_block(nn.Module):
    r"""One spatial temporal graph convolution unit.

    The input goes through a graph convolution (``gcn``) followed by a
    temporal stack (``tcn``: batch norm, ReLU, temporal convolution, batch
    norm, dropout). A residual branch is added to the result before the
    final ReLU.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): ``(temporal, spatial)`` kernel sizes; the temporal
            size must be odd
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (float, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        temporal_size = kernel_size[0]
        spatial_size = kernel_size[1]
        # Half the (odd) temporal kernel on each side, nothing on the node axis.
        temporal_padding = ((temporal_size - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         spatial_size)

        temporal_layers = [
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,
                      out_channels,
                      kernel_size=(temporal_size, 1),
                      stride=(stride, 1),
                      padding=temporal_padding),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        ]
        self.tcn = nn.Sequential(*temporal_layers)

        shape_preserved = (in_channels == out_channels) and (stride == 1)
        if residual and shape_preserved:
            # Input and output shapes agree, so the input is added as-is.
            self.residual = iden
        elif residual:
            # Project the input so it matches the main branch's channel
            # count and temporal stride.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.residual = zero

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """Run the block on ``x`` with adjacency ``A``.

        Returns:
            tuple: the activated output tensor and the adjacency returned by
            the graph convolution.
        """
        # The shortcut is taken from the block input, before ``x`` is rebound.
        shortcut = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + shortcut

        return self.relu(x), A
| |
class ST_GCN_18(nn.Module):
    r"""Spatial temporal graph convolutional feature extractor.

    Ten ``st_gcn_block`` units are stacked on a default-constructed
    ``Graph``. The channel widths are 64, 64, 64, 64, 128, 128, 128, 256,
    256, 512, with a temporal stride of 2 at the blocks that widen to 128
    and to 256. The output is pooled over time and nodes and averaged over
    the instances in a frame.

    Args:
        in_channels (int): Number of channels in the input data
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        data_bn (bool): If ``True``, applies batch normalisation to the input
            before the first block
        **kwargs (optional): Other parameters for graph convolution units
            (``dropout`` is not forwarded to the first unit)

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, 512, 1, 1)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instance in a frame.
    """

    def __init__(self,
                 in_channels,
                 edge_importance_weighting=True,
                 data_bn=True,
                 **kwargs):
        super().__init__()

        # Graph structure, stored as a non-trainable buffer.
        self.graph = Graph()
        A = torch.tensor(self.graph.A,
                         dtype=torch.float32,
                         requires_grad=False)
        self.register_buffer('A', A)

        # The spatial kernel size equals the number of adjacency subsets.
        kernel_size = (9, A.size(0))

        if data_bn:
            self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        else:
            self.data_bn = iden

        # The first unit never receives ``dropout`` and has no residual branch.
        first_kwargs = dict(kwargs)
        first_kwargs.pop('dropout', None)
        first_block = st_gcn_block(in_channels,
                                   64,
                                   kernel_size,
                                   1,
                                   residual=False,
                                   **first_kwargs)

        # (in_channels, out_channels, temporal stride) of the remaining units.
        layout = [
            (64, 64, 1),
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 1),
        ]
        later_blocks = [
            st_gcn_block(width_in, width_out, kernel_size, stride, **kwargs)
            for width_in, width_out, stride in layout
        ]
        self.st_gcn_networks = nn.ModuleList([first_block] + later_blocks)

        # One edge-importance mask per unit: learnable, or the constant 1.
        block_count = len(self.st_gcn_networks)
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in range(block_count)
            ])
        else:
            self.edge_importance = [1] * block_count

    def forward(self, x):
        """Compute pooled features for a batch of skeleton sequences."""
        batch, channels, frames, nodes, persons = x.size()

        # Input normalisation runs on a (N*M, V*C, T) layout.
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(batch * persons, nodes * channels, frames)
        x = self.data_bn(x)

        # Back to (N*M, C, T, V) for the graph convolution units.
        x = x.view(batch, persons, nodes, channels, frames)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(batch * persons, channels, frames, nodes)

        for block, importance in zip(self.st_gcn_networks,
                                     self.edge_importance):
            x, _ = block(x, self.A * importance)

        # Average over the remaining time and node axes, then over persons.
        pooled = F.avg_pool2d(x, x.size()[2:])
        return pooled.view(batch, persons, -1, 1, 1).mean(dim=1)