# Source code for UM2N.model.extractor

# Author: Chunyang Wang
# GitHub Username: acse-cw1722

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from transformer_model import TransformerModel

# Public names exported by ``from <this module> import *``.
__all__ = [
    "LocalFeatExtractor",
    "GlobalFeatExtractor",
]


class LocalFeatExtractor(MessagePassing):
    """
    Message-passing layer that extracts features from local graph structure.

    For every edge the message is assembled from the coordinate offset
    between the two endpoints, the Euclidean length of that offset, and the
    non-coordinate features of both endpoints. The message is fed through a
    three-layer MLP (SELU after each layer), the edge length is appended to
    the result, and messages are combined with additive aggregation.

    Attributes:
        lin_1 (torch.nn.Linear): First linear layer.
        lin_2 (torch.nn.Linear): Second linear layer.
        lin_3 (torch.nn.Linear): Third linear layer, with ``out - 1`` outputs.
        activate (torch.nn.SELU): Activation applied after each linear layer.
    """

    def __init__(self, num_feat=10, out=16):
        """
        Build the edge MLP.

        Args:
            num_feat (int): Number of input features per node. The first two
                columns are treated as coordinates, the rest as features.
            out (int): Number of output features per node.
        """
        super().__init__(aggr="add")
        # Edge input width: coord offset (2) + distance (1)
        # + non-coordinate features of both endpoints (2 * (num_feat - 2)).
        edge_in_width = 2 * (num_feat - 2) + 3
        self.lin_1 = torch.nn.Linear(edge_in_width, 64)
        self.lin_2 = torch.nn.Linear(64, 64)
        # One output column is reserved: the edge distance is concatenated
        # after the MLP, bringing the message width back to ``out``.
        self.lin_3 = torch.nn.Linear(64, out - 1)
        self.activate = torch.nn.SELU()

    def forward(self, input, edge_index):
        """
        Forward pass.

        Args:
            input (Tensor): Node features; first two columns are coordinates.
            edge_index (Tensor): Edge indices.

        Returns:
            Tensor: Aggregated node features with ``out`` columns.
        """
        return self.propagate(edge_index, x=input)

    def message(self, x_i, x_j):
        """
        Compute one message per edge.

        Args:
            x_i (Tensor): Features of the edge's ``i`` endpoint.
            x_j (Tensor): Features of the edge's ``j`` endpoint.

        Returns:
            Tensor: MLP output (``out - 1`` columns) followed by the edge
            length (1 column).
        """
        n_coord = 2
        offset = x_j[:, :n_coord] - x_i[:, :n_coord]
        length = offset.norm(dim=1, keepdim=True)
        hidden = torch.cat(
            [offset, length, x_i[:, n_coord:], x_j[:, n_coord:]], dim=1
        )
        for layer in (self.lin_1, self.lin_2, self.lin_3):
            hidden = self.activate(layer(hidden))
        return torch.cat([hidden, length], dim=1)
class GlobalFeatExtractor(torch.nn.Module):
    """
    Custom PyTorch layer for global feature extraction.

    Four 2-D convolutions, each followed by SELU (and dropout when enabled),
    are applied to the input; an adaptive average pool then collapses the
    spatial dimensions and the result is reshaped to ``[-1, out_c]``.

    Attributes:
        conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
        dropout (torch.nn.Dropout or None): Dropout layer; ``None`` when
            ``use_drop`` is False.
        final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.
    """

    def __init__(self, in_c, out_c, drop_p=0.2, use_drop=True):
        """
        Initialize the layer.

        Args:
            in_c (int): Number of input channels.
            out_c (int): Number of output channels.
            drop_p (float, optional): Dropout probability.
            use_drop (bool, optional): Whether to use the dropout layer. When
                set to ``False``, this building block is exactly the block
                used in the original M2N model without any change. When set
                to ``True``, it is used for the MRN model.
        """
        super().__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.conv1 = torch.nn.Conv2d(in_c, 32, 3, padding=1, stride=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2, stride=1)
        # NOTE: conv3/conv4 pair a 3x3 kernel with padding=2, so each grows
        # the spatial size by 2. Kept unchanged because altering the padding
        # would change the pooled outputs.
        self.conv3 = torch.nn.Conv2d(64, 32, 3, padding=2, stride=1)
        self.conv4 = torch.nn.Conv2d(32, out_c, 3, padding=2, stride=1)
        self.use_drop = use_drop
        self.dropout = torch.nn.Dropout(drop_p) if use_drop else None
        self.final_pool = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, data):
        """
        Forward pass.

        Args:
            data (Tensor): Input accepted by ``torch.nn.Conv2d`` with
                ``in_c`` channels.

        Returns:
            Tensor: Extracted global features, reshaped to ``[-1, out_c]``.
        """
        x = data
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.selu(conv(x))
            if self.use_drop:
                x = self.dropout(x)
        x = self.final_pool(x)
        return x.reshape(-1, self.out_c)
class TransformerEncoder(torch.nn.Module):
    """
    Transformer-based feature extractor for mesh node features.

    The node features of each sample are treated as a sequence and passed to
    ``TransformerModel`` with the same tensor used as query, key and value.
    In training mode, a random subset of nodes can optionally be masked,
    either through a key padding mask or through an attention mask. SELU is
    applied to the transformer output.

    Attributes:
        num_transformer_in (int): Input feature size per node.
        num_transformer_out (int): Output feature size per node.
        num_transformer_embed_dim (int): Transformer embedding dimension.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        transformer_encoder (TransformerModel): Underlying transformer.
        mask_in_trainig (bool): Enable random node masking in training mode
            (attribute name spelling kept for compatibility).
        key_padding_mask_in_training (bool): Mask via key padding mask.
        attention_mask_in_training (bool): Mask via attention mask (used only
            when key padding masking is disabled).
        mask_ratio_lb, mask_ratio_ub (float): Bounds of the uniformly sampled
            fraction of masked nodes.
        device: Value of the ``device`` argument (kept for compatibility;
            masks are created on the input tensor's device).
    """

    def __init__(
        self,
        num_transformer_in=3,
        num_transformer_out=16,
        num_transformer_embed_dim=64,
        num_transformer_heads=4,
        num_transformer_layers=1,
        transformer_training_mask=False,
        transformer_key_padding_training_mask=False,
        transformer_attention_training_mask=False,
        transformer_training_mask_ratio_lower_bound=0.5,
        transformer_training_mask_ratio_upper_bound=0.9,
        deform_in_c=7,
        deform_out_type="coord",
        num_loop=3,
        device="cuda",
    ):
        """
        Initialize the transformer encoder.

        Args:
            num_transformer_in (int): Input feature size per node.
            num_transformer_out (int): Output feature size per node.
            num_transformer_embed_dim (int): Transformer embedding dimension.
            num_transformer_heads (int): Number of attention heads.
            num_transformer_layers (int): Number of transformer layers.
            transformer_training_mask (bool): Enable random masking while
                training.
            transformer_key_padding_training_mask (bool): Use a key padding
                mask for the random masking.
            transformer_attention_training_mask (bool): Use an attention mask
                for the random masking (if key padding masking is off).
            transformer_training_mask_ratio_lower_bound (float): Lower bound
                of the masked-node fraction.
            transformer_training_mask_ratio_upper_bound (float): Upper bound
                of the masked-node fraction.
            deform_in_c, deform_out_type, num_loop: Not used by this class;
                accepted for interface compatibility.
            device: Stored as ``self.device``.
        """
        super().__init__()
        self.device = device
        self.mask_in_trainig = transformer_training_mask
        self.key_padding_mask_in_training = transformer_key_padding_training_mask
        self.attention_mask_in_training = transformer_attention_training_mask
        self.mask_ratio_ub = transformer_training_mask_ratio_upper_bound
        self.mask_ratio_lb = transformer_training_mask_ratio_lower_bound
        assert (
            self.mask_ratio_ub >= self.mask_ratio_lb
        ), "Training mask ratio upper bound smaller than lower bound."
        self.num_transformer_in = num_transformer_in
        self.num_transformer_out = num_transformer_out
        self.num_transformer_embed_dim = num_transformer_embed_dim
        self.num_heads = num_transformer_heads
        self.num_layers = num_transformer_layers
        self.transformer_encoder = TransformerModel(
            input_dim=self.num_transformer_in,
            embed_dim=self.num_transformer_embed_dim,
            output_dim=self.num_transformer_out,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
        )

    def _transformer_forward(self, batch_size, input_q, input_kv, get_attens=False):
        """
        Run the transformer on flattened node features.

        Args:
            batch_size (int): Number of samples in the batch.
            input_q (Tensor): Query features; viewed as
                ``[batch_size, -1, input_q.shape[-1]]``.
            input_kv (Tensor): Key/value features; viewed the same way.
            get_attens (bool): If True, also return attention scores.

        Returns:
            Tensor: SELU-activated transformer output, or the tuple
            ``(features, atten_scores)`` when ``get_attens`` is True.
        """
        query = input_q.view(batch_size, -1, input_q.shape[-1])
        key_value = input_kv.view(batch_size, -1, input_kv.shape[-1])
        node_num = query.shape[1]

        key_padding_mask = None
        attention_mask = None
        # ``self.training`` is the mode flag; ``self.train`` is a method and
        # is always truthy, which would enable masking during evaluation too.
        if self.training and self.mask_in_trainig:
            mask_ratio = (self.mask_ratio_ub - self.mask_ratio_lb) * torch.rand(
                1
            ) + self.mask_ratio_lb
            masked_num = int(node_num * mask_ratio)
            mask = torch.randperm(node_num)[:masked_num]
            if self.key_padding_mask_in_training:
                # Masks are created on the input's device so CPU inputs work
                # regardless of the ``device`` constructor argument.
                key_padding_mask = torch.zeros(
                    [batch_size, node_num], dtype=torch.bool, device=query.device
                )
                key_padding_mask[:, mask] = True
            elif self.attention_mask_in_training:
                attention_mask = torch.zeros(
                    [batch_size * self.num_heads, node_num, node_num],
                    dtype=torch.bool,
                    device=query.device,
                )
                # NOTE(review): two index tensors are paired elementwise here,
                # so only the diagonal entries (m, m) for m in ``mask`` are set,
                # not the full mask x mask block — confirm this is intended.
                attention_mask[:, mask, mask] = True

        features = self.transformer_encoder(
            query,
            key_value,
            key_value,
            key_padding_mask=key_padding_mask,
            attention_mask=attention_mask,
        )
        features = F.selu(features)
        if not get_attens:
            return features
        # TODO: adapt q k v
        atten_scores = self.transformer_encoder.get_attention_scores(
            x=query, key_padding_mask=key_padding_mask
        )
        return features, atten_scores

    def forward(self, data):
        """
        Extract per-node features.

        Args:
            data (Tensor): Node features whose leading dimension is the batch
                size and whose last dimension is the feature size.

        Returns:
            Tensor: Features flattened to ``[-1, out_feat]``.
        """
        batch_size = data.shape[0]
        feat_dim = data.shape[-1]
        # ``reshape`` (not ``view``) also accepts non-contiguous input.
        flat = data.reshape(-1, feat_dim)
        hidden = self._transformer_forward(batch_size, flat, flat)
        return hidden.reshape(-1, hidden.shape[-1])