"""Feature extractor modules for UM2N (``UM2N.model.extractor``)."""
# Author: Chunyang Wang
# GitHub Username: acse-cw1722
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from transformer_model import TransformerModel
# Public API of this module for ``from ... import *``.
# NOTE(review): TransformerEncoder (defined in this file) is not listed here,
# so star-imports will not export it — confirm whether that is intended.
__all__ = ["LocalFeatExtractor", "GlobalFeatExtractor"]
class LocalFeatExtractor(MessagePassing):
    """
    Custom PyTorch geometric layer that performs feature extraction
    on local graph structure.

    The class extends the torch_geometric.nn.MessagePassing class
    and employs additive aggregation as the message-passing scheme.
    The first two features of every node are treated as coordinates and
    the remaining ``num_feat - 2`` features as node features.

    Attributes:
        lin_1 (torch.nn.Linear): First linear layer.
        lin_2 (torch.nn.Linear): Second linear layer.
        lin_3 (torch.nn.Linear): Third linear layer.
        activate (torch.nn.SELU): Activation function.
    """

    # Number of leading node features interpreted as coordinates. Shared by
    # ``__init__`` (input width of lin_1) and ``message`` (slicing) so the
    # two can never disagree.
    _NUM_COORD = 2

    def __init__(self, num_feat=10, out=16):
        """
        Initialize the layer.

        Args:
            num_feat (int): Number of input features per node; the first
                two are the node coordinates.
            out (int): Number of output features per node.

        Raises:
            ValueError: If ``num_feat`` is smaller than 2, i.e. the node
                features cannot hold the two coordinates.
        """
        super().__init__(aggr="add")
        n_coord = self._NUM_COORD
        if num_feat < n_coord:
            raise ValueError(
                f"num_feat must be >= {n_coord} (coordinate features), "
                f"got {num_feat}"
            )
        # Edge input: coord diff (2) + distance (1) + node feats of both ends.
        num_in_feat = n_coord + 1 + 2 * (num_feat - n_coord)
        self.lin_1 = torch.nn.Linear(num_in_feat, 64)
        self.lin_2 = torch.nn.Linear(64, 64)
        # minus 1 because the coordinate distance is appended back to the
        # MLP output in ``message``, giving ``out`` features in total.
        self.lin_3 = torch.nn.Linear(64, out - 1)
        self.activate = torch.nn.SELU()

    def forward(self, input, edge_index):
        """
        Forward pass.

        Args:
            input (Tensor): Node features; the first two columns are the
                node coordinates.
            edge_index (Tensor): Edge indices.

        Returns:
            Tensor: Updated node features (sum of incoming messages, each
            of width ``out``).
        """
        local_feat = self.propagate(edge_index, x=input)
        return local_feat

    def message(self, x_i, x_j):
        """
        Build one message per edge from its two endpoint features.

        Args:
            x_i (Tensor): Per-edge features of the receiving node (PyG's
                ``_i`` suffix).
            x_j (Tensor): Per-edge features of the sending node (PyG's
                ``_j`` suffix).

        Returns:
            Tensor: Per-edge messages: the 3-layer SELU MLP output
            (``out - 1`` channels) followed by the coordinate distance.
        """
        n_coord = self._NUM_COORD
        x_i_coord = x_i[:, :n_coord]
        x_j_coord = x_j[:, :n_coord]
        x_i_feat = x_i[:, n_coord:]
        x_j_feat = x_j[:, n_coord:]
        x_coord_diff = x_j_coord - x_i_coord
        x_coord_dist = torch.norm(x_coord_diff, dim=1, keepdim=True)
        x_edge_feat = torch.cat(
            [x_coord_diff, x_coord_dist, x_i_feat, x_j_feat], dim=1
        )  # [num_edges, num_in_feat]
        x_edge_feat = self.activate(self.lin_1(x_edge_feat))
        x_edge_feat = self.activate(self.lin_2(x_edge_feat))
        x_edge_feat = self.activate(self.lin_3(x_edge_feat))
        # Re-attach the raw distance so it reaches the aggregation unchanged.
        x_edge_feat = torch.cat([x_edge_feat, x_coord_dist], dim=1)
        return x_edge_feat
class GlobalFeatExtractor(torch.nn.Module):
    """
    Custom PyTorch layer for global feature extraction.

    The block stacks four SELU-activated 2D convolutions, each optionally
    followed by dropout, and average-pools the result to a single vector
    of ``out_c`` features per sample.

    Attributes:
        conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
        dropout (torch.nn.Dropout or None): Dropout layer; ``None`` when
            ``use_drop`` is ``False``.
        final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.
    """

    def __init__(self, in_c, out_c, drop_p=0.2, use_drop=True):
        """
        Initialize the layer.

        Args:
            in_c (int): Number of input channels.
            out_c (int): Number of output channels.
            drop_p (float, optional): Dropout probability.
            use_drop (bool, optional): Whether to apply dropout after each
                convolution. When set to ``False``, this building block is
                exactly the block used in the original M2N model without
                any change. When set to ``True``, it is used for the MRN
                model.
        """
        super().__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.conv1 = torch.nn.Conv2d(in_c, 32, 3, padding=1, stride=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2, stride=1)
        # conv3/conv4 pair a 3x3 kernel with padding=2, so each grows the
        # spatial size by 2; the adaptive pool below absorbs this.
        self.conv3 = torch.nn.Conv2d(64, 32, 3, padding=2, stride=1)
        self.conv4 = torch.nn.Conv2d(32, out_c, 3, padding=2, stride=1)
        self.use_drop = use_drop
        self.dropout = torch.nn.Dropout(drop_p) if use_drop else None
        self.final_pool = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, data):
        """
        Forward pass.

        Args:
            data (Tensor): Input data with ``in_c`` channels, laid out as
                expected by ``torch.nn.Conv2d``.

        Returns:
            Tensor: Extracted global features reshaped to ``[-1, out_c]``.
        """
        x = data
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.selu(conv(x))
            if self.use_drop:
                x = self.dropout(x)
        x = self.final_pool(x)
        return x.reshape(-1, self.out_c)
class TransformerEncoder(torch.nn.Module):
    """
    Transformer-based encoder for per-node features.

    Wraps a ``TransformerModel`` and applies it to the input reshaped to
    ``[batch_size, num_nodes, feat_dim]``, using the same tensor as query,
    key and value. While training, a random subset of nodes can optionally
    be masked out via a key padding mask or an attention mask.

    Attributes:
        device (str): Device string given at construction (kept for
            backward compatibility; masks are created on the input's device).
        mask_in_trainig (bool): Whether random masking is enabled in training.
        key_padding_mask_in_training (bool): Mask via a key padding mask.
        attention_mask_in_training (bool): Mask via an attention mask (only
            used when the key padding mask is disabled).
        mask_ratio_lb, mask_ratio_ub (float): Bounds of the masked-node ratio.
        num_transformer_in, num_transformer_out, num_transformer_embed_dim
            (int): Transformer input / output / embedding sizes.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        transformer_encoder (TransformerModel): Wrapped transformer.
    """

    def __init__(
        self,
        num_transformer_in=3,
        num_transformer_out=16,
        num_transformer_embed_dim=64,
        num_transformer_heads=4,
        num_transformer_layers=1,
        transformer_training_mask=False,
        transformer_key_padding_training_mask=False,
        transformer_attention_training_mask=False,
        transformer_training_mask_ratio_lower_bound=0.5,
        transformer_training_mask_ratio_upper_bound=0.9,
        deform_in_c=7,
        deform_out_type="coord",
        num_loop=3,
        device="cuda",
    ):
        """
        Initialize the encoder.

        Args:
            num_transformer_in (int): Input feature size of the transformer.
            num_transformer_out (int): Output feature size of the transformer.
            num_transformer_embed_dim (int): Transformer embedding size.
            num_transformer_heads (int): Number of attention heads.
            num_transformer_layers (int): Number of transformer layers.
            transformer_training_mask (bool): Enable random node masking
                while training.
            transformer_key_padding_training_mask (bool): Mask with a key
                padding mask.
            transformer_attention_training_mask (bool): Mask with an
                attention mask (ignored if the key padding mask is enabled).
            transformer_training_mask_ratio_lower_bound (float): Lower bound
                of the fraction of masked nodes.
            transformer_training_mask_ratio_upper_bound (float): Upper bound
                of the fraction of masked nodes.
            deform_in_c, deform_out_type, num_loop: Unused by this class;
                accepted for interface compatibility.
            device (str): Stored as ``self.device``.

        Raises:
            AssertionError: If the upper mask-ratio bound is smaller than the
                lower bound.
        """
        super().__init__()
        self.device = device
        self.mask_in_trainig = transformer_training_mask
        self.key_padding_mask_in_training = transformer_key_padding_training_mask
        self.attention_mask_in_training = transformer_attention_training_mask
        self.mask_ratio_ub = transformer_training_mask_ratio_upper_bound
        self.mask_ratio_lb = transformer_training_mask_ratio_lower_bound
        assert (
            self.mask_ratio_ub >= self.mask_ratio_lb
        ), "Training mask ratio upper bound smaller than lower bound."
        self.num_transformer_in = num_transformer_in
        self.num_transformer_out = num_transformer_out
        self.num_transformer_embed_dim = num_transformer_embed_dim
        self.num_heads = num_transformer_heads
        self.num_layers = num_transformer_layers
        self.transformer_encoder = TransformerModel(
            input_dim=self.num_transformer_in,
            embed_dim=self.num_transformer_embed_dim,
            output_dim=self.num_transformer_out,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
        )

    def _transformer_forward(self, batch_size, input_q, input_kv, get_attens=False):
        """
        Run the wrapped transformer on flattened node features.

        Args:
            batch_size (int): Number of samples in the batch.
            input_q (Tensor): Query features; reshaped to
                ``[batch_size, -1, input_q.shape[-1]]``.
            input_kv (Tensor): Key/value features; reshaped to
                ``[batch_size, -1, input_kv.shape[-1]]``.
            get_attens (bool, optional): Also return attention scores.

        Returns:
            Tensor: SELU-activated transformer output, or a tuple
            ``(features, atten_scores)`` when ``get_attens`` is ``True``.
        """
        transformer_input_q = input_q.reshape(batch_size, -1, input_q.shape[-1])
        transformer_input_kv = input_kv.reshape(batch_size, -1, input_kv.shape[-1])
        node_num = transformer_input_q.shape[1]
        key_padding_mask = None
        attention_mask = None
        # ``self.training`` is the train/eval flag; ``self.train`` is a bound
        # method and always truthy, which used to enable masking in eval too.
        if self.training and self.mask_in_trainig:
            mask_ratio = (self.mask_ratio_ub - self.mask_ratio_lb) * torch.rand(
                1
            ) + self.mask_ratio_lb
            masked_num = int(node_num * mask_ratio)
            mask = torch.randperm(node_num)[:masked_num]
            # Build masks where the inputs live so they always match the
            # transformer's tensors, independent of ``self.device``.
            mask_device = transformer_input_kv.device
            if self.key_padding_mask_in_training:
                key_padding_mask = torch.zeros(
                    [batch_size, node_num], dtype=torch.bool, device=mask_device
                )
                key_padding_mask[:, mask] = True
            elif self.attention_mask_in_training:
                attention_mask = torch.zeros(
                    [batch_size * self.num_heads, node_num, node_num],
                    dtype=torch.bool,
                    device=mask_device,
                )
                # NOTE(review): paired advanced indexing sets only the
                # diagonal entries (m, m) of each masked node m, not whole
                # rows/columns — confirm this is the intended mask.
                attention_mask[:, mask, mask] = True
        features = self.transformer_encoder(
            transformer_input_q,
            transformer_input_kv,
            transformer_input_kv,
            key_padding_mask=key_padding_mask,
            attention_mask=attention_mask,
        )
        features = F.selu(features)
        if not get_attens:
            return features
        # TODO: adapt q k v
        atten_scores = self.transformer_encoder.get_attention_scores(
            x=transformer_input_q, key_padding_mask=key_padding_mask
        )
        return features, atten_scores

    def forward(self, data):
        """
        Encode per-node features with the transformer.

        Args:
            data (Tensor): Input whose first dimension is the batch size and
                whose last dimension is the feature size (presumably
                ``[batch_size, num_nodes, feat_dim]`` — confirm with callers).

        Returns:
            Tensor: Encoded features flattened to ``[-1, out_dim]``, where
            ``out_dim`` is the last dimension of the transformer output.
        """
        batch_size = data.shape[0]
        feat_dim = data.shape[-1]
        # The same features serve as query and key/value (self-attention).
        flat = data.reshape(-1, feat_dim)
        hidden = self._transformer_forward(batch_size, flat, flat)
        return hidden.reshape(-1, hidden.shape[-1])