Source code for UM2N.model.MRT
# Author: Chunyang Wang
# GitHub Username: acse-cw1722
# Modified by Mingrui Zhang
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
cur_dir = os.path.dirname(__file__)
sys.path.append(cur_dir)
from deformer import RecurrentGATConv  # noqa: E402
from transformer_model import TransformerModel
__all__ = ["MRTransformer"]
[docs]
class MRTransformer(torch.nn.Module):
    """
    Mesh Refinement Network (MRN) implementing transformer as feature
        extrator and recurrent graph-based deformations.
    Attributes:
        num_loop (int): Number of loops for the recurrent layer.
        gfe_out_c (int): Output channels for global feature extractor.
        lfe_out_c (int): Output channels for local feature extractor.
        hidden_size (int): Size of the hidden layer.
        gfe (GlobalFeatExtractor): Global feature extractor.
        lfe (LocalFeatExtractor): Local feature extractor.
        lin (nn.Linear): Linear layer for feature transformation.
        deformer (RecurrentGATConv): GAT-based deformer block.
    """
    def __init__(
        self,
        num_transformer_in=4,
        num_transformer_out=16,
        num_transformer_embed_dim=64,
        num_transformer_heads=4,
        num_transformer_layers=1,
        transformer_training_mask=False,
        transformer_key_padding_training_mask=False,
        transformer_attention_training_mask=False,
        transformer_training_mask_ratio_lower_bound=0.5,
        transformer_training_mask_ratio_upper_bound=0.9,
        deform_in_c=7,
        deform_out_type="coord",
        num_loop=3,
        device="cuda",
    ):
        """
        Initialize MRN.
        Args:
            gfe_in_c (int): Input channels for the global feature extractor.
            lfe_in_c (int): Input channels for the local feature extractor.
            deform_in_c (int): Input channels for the deformer block.
            num_loop (int): Number of loops for the recurrent layer.
        """
        super().__init__()
        self.device = device
        self.num_loop = num_loop
        self.hidden_size = 512  # set here
        self.mask_in_trainig = transformer_training_mask
        self.key_padding_mask_in_training = transformer_key_padding_training_mask
        self.attention_mask_in_training = transformer_attention_training_mask
        self.mask_ratio_ub = transformer_training_mask_ratio_upper_bound
        self.mask_ratio_lb = transformer_training_mask_ratio_lower_bound
        assert (
            self.mask_ratio_ub >= self.mask_ratio_lb
        ), "Training mask ratio upper bound smaller than lower bound."
        self.num_transformer_in = num_transformer_in
        self.num_transformer_out = num_transformer_out
        self.num_transformer_embed_dim = num_transformer_embed_dim
        self.num_heads = num_transformer_heads
        self.num_layers = num_transformer_layers
        self.transformer_encoder = TransformerModel(
            input_dim=self.num_transformer_in,
            embed_dim=self.num_transformer_embed_dim,
            output_dim=self.num_transformer_out,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
        )
        self.all_feat_c = (deform_in_c - 2) + self.num_transformer_out
        # use a linear layer to transform the input feature to hidden
        # state size
        self.lin = nn.Linear(self.all_feat_c, self.hidden_size)
        # Mapping embedding to monitor
        self.to_monitor_1 = nn.Linear(self.hidden_size, self.hidden_size // 8)
        self.to_monitor_2 = nn.Linear(self.hidden_size // 8, self.hidden_size // 16)
        self.to_monitor_3 = nn.Linear(self.hidden_size // 16, 1)
        self.deformer = RecurrentGATConv(
            coord_size=2,
            hidden_size=self.hidden_size,
            heads=6,
            concat=False,
            output_type=deform_out_type,
            device=device,
        )
    def _transformer_forward(
        self, batch_size, input_q, input_kv, boundary, get_attens=False
    ):
        """
        Forward pass for MRN.
        Args:
            data (Data): Input data object containing mesh and feature info.
        Returns:
            coord (Tensor): Deformed coordinates.
        """
        # mesh_feat: [num_nodes * batch_size, 4]
        # mesh_feat [coord_x, coord_y, u, hessian_norm]
        transformer_input_q = input_q.view(batch_size, -1, input_q.shape[-1])
        transformer_input_kv = input_kv.view(batch_size, -1, input_kv.shape[-1])
        node_num = transformer_input_q.shape[1]
        # print(transformer_input_q.shape, transformer_input_kv.shape)
        key_padding_mask = None
        attention_mask = None
        if self.train and self.mask_in_trainig:
            mask_ratio = (self.mask_ratio_ub - self.mask_ratio_lb) * torch.rand(
                1
            ) + self.mask_ratio_lb
            masked_num = int(node_num * mask_ratio)
            mask = torch.randperm(node_num)[:masked_num]
            if self.key_padding_mask_in_training:
                # Key padding mask
                key_padding_mask = torch.zeros(
                    [batch_size, node_num], dtype=torch.bool
                ).to(self.device)
                key_padding_mask[:, mask] = True
            # print(key_padding_mask.shape, key_padding_mask)
            # print("Now is training")
            elif self.attention_mask_in_training:
                # Attention mask
                attention_mask = torch.zeros(
                    [batch_size * self.num_heads, node_num, node_num], dtype=torch.bool
                ).to(self.device)
                attention_mask[:, mask, mask] = True
        features = self.transformer_encoder(
            transformer_input_q,
            transformer_input_kv,
            transformer_input_kv,
            key_padding_mask=key_padding_mask,
            attention_mask=attention_mask,
        )
        features = features.view(-1, self.num_transformer_out)
        features = torch.cat([boundary, features], dim=1)
        # print(f"transformer raw features: {features.shape}")
        features = F.selu(self.lin(features))
        if not get_attens:
            return features
        else:
            # TODO: adapt q k v
            atten_scores = self.transformer_encoder.get_attention_scores(
                x=transformer_input_q, key_padding_mask=key_padding_mask
            )
            return features, atten_scores
[docs]
    def transformer_monitor(self, data, input_q, input_kv, boundary):
        batch_size = data.conv_feat.shape[0]
        # [coord_ori_x, coord_ori_y, u, hessian_norm]
        # intput_features = torch.cat([coord_ori, data.mesh_feat[:, 2:4]], dim=-1)
        # print(f"input q shape: {input_q.shape} input kv shape: {input_kv.shape}")
        hidden = self._transformer_forward(batch_size, input_q, input_kv, boundary)
        # TODO: more sampling points inspired by neural operator
        # edge_idx = data.edge_index_with_cluster.reshape(2, -1)
        # print("input data after reshape ", edge_idx.shape)
        # ===== Ablation for hessian norm as direct input to the deformer =====
        # hidden = data.mesh_feat[:, -1].unsqueeze(-1)
        # hidden = torch.cat([x_feat[:, 2:], hidden], dim=1)
        # hidden = F.selu(self.lin(hidden))
        # =====================================================================
        return hidden
[docs]
    def move(
        self,
        data,
        input_q,
        input_kv,
        mesh_query,
        sampled_queries,
        sampled_queries_edge_index,
        num_step=1,
        poly_mesh=False,
    ):
        """
        Move the mesh according to the deformation learned, with given number
            steps.
        Args:
            data (Data): Input data object containing mesh and feature info.
            num_step (int): Number of deformation steps.
        Returns:
            coord (Tensor): Deformed coordinates.
        """
        bd_mask = data.bd_mask
        poly_mesh = False
        if data.poly_mesh is not False:
            poly_mesh = True if data.poly_mesh.sum() > 0 else False
        edge_idx = data.edge_index
        boundary = data.x[:, 2:]
        hidden = self.transformer_monitor(data, input_q, input_kv, boundary)
        coord = mesh_query
        model_output = None
        out_monitor = None
        # Recurrent GAT deform
        for i in range(num_step):
            (coord, model_output), hidden, (phix, phiy) = self.deformer(
                coord, hidden, edge_idx, mesh_query, bd_mask, poly_mesh
            )
        coord_extra = sampled_queries
        # Recurrent GAT deform (extra sampled)
        for i in range(num_step):
            (coord_extra, model_output_extra), hidden, (phix_extra, phiy_extra) = (
                self.deformer(
                    coord_extra,
                    hidden,
                    sampled_queries_edge_index,
                    sampled_queries,
                    bd_mask,
                    poly_mesh,
                )
            )
        coord_output = torch.cat([coord, coord_extra], dim=0)
        model_raw_output = torch.cat([model_output, model_output_extra], dim=0)
        # phix_output = torch.cat([phix, phix_extra], dim=0)
        # phiy_output = torch.cat([phiy, phiy_extra], dim=0)
        phix_output = phix_extra
        phiy_output = phiy_extra
        # print(phix.shape, phix_extra.shape)
        return (coord_output, model_raw_output, out_monitor), (phix_output, phiy_output)
[docs]
    def forward(
        self,
        data,
        input_q,
        input_kv,
        mesh_query,
        sampled_queries,
        sampled_queries_edge_index,
        poly_mesh=False,
    ):
        """
        Forward pass for MRN.
        Args:
            data (Data): Input data object containing mesh and feature info.
        Returns:
            coord (Tensor): Deformed coordinates.
        """
        bd_mask = data.bd_mask
        poly_mesh = False
        if data.poly_mesh is not False:
            poly_mesh = True if data.poly_mesh.sum() > 0 else False
        edge_idx = data.edge_index
        boundary = data.x[:, 2:].view(-1, 1)
        hidden = self.transformer_monitor(data, input_q, input_kv, boundary)
        coord = mesh_query
        model_output = None
        out_monitor = None
        # Recurrent GAT deform
        for i in range(self.num_loop):
            (coord, model_output), hidden, (phix, phiy) = self.deformer(
                coord, hidden, edge_idx, mesh_query, bd_mask, poly_mesh
            )
        if sampled_queries is not None:
            coord_extra = sampled_queries
            # Recurrent GAT deform (extra sampled)
            for i in range(self.num_loop):
                (coord_extra, model_output_extra), hidden, (phix_extra, phiy_extra) = (
                    self.deformer(
                        coord_extra,
                        hidden,
                        sampled_queries_edge_index,
                        sampled_queries,
                        bd_mask,
                        poly_mesh,
                    )
                )
            coord_output = torch.cat([coord, coord_extra], dim=0)
            model_raw_output = torch.cat([model_output, model_output_extra], dim=0)
            # # phix_output = torch.cat([phix, phix_extra], dim=0)
            # # phiy_output = torch.cat([phiy, phiy_extra], dim=0)
            phix_output = phix_extra
            phiy_output = phiy_extra
            # print(phix.shape, phix_extra.shape)
        else:
            coord_output = coord
            model_raw_output = model_output
            phix_output = phix
            phiy_output = phiy
        return (coord_output, model_raw_output, out_monitor), (phix_output, phiy_output)
        # return (coord, model_output, out_monitor), (phix, phiy)
[docs]
    def get_attention_scores(self, data):
        conv_feat_in = data.conv_feat
        batch_size = batch_size = conv_feat_in.shape[0]
        feat_dim = data.x.shape[-1]
        x_feat = data.x.view(-1, feat_dim)
        # coord = x_feat[:, :2]
        # edge_idx = data.edge_index
        _, attentions = self._transformer_forward(
            batch_size, data.mesh_feat[:, :4], x_feat, get_attens=True
        )
        return attentions