Source code for UM2N.model.M2T_deformer
import os
import sys
import torch
import torch.nn as nn
from extractor import LocalFeatExtractor
from M2N import NetGATDeform
from torch_geometric.nn import GATv2Conv, MessagePassing
cur_dir = os.path.dirname(__file__)
sys.path.append(cur_dir)
__all__ = ["M2TDeformer"]
[docs]
class M2TDeformer(MessagePassing):
    """
    Implements a M2TDeformer.
    Attributes:
        to_hidden (GATv2Conv): Graph Attention layer.
        to_coord (nn.Sequential): Output layer for coordinates.
        activation (nn.SELU): Activation function.
    """
    def __init__(
        self,
        feature_in_dim,
        local_feature_dim_in,
        coord_size=2,
        hidden_size=512,
        heads=6,
        output_type="coord",
        concat=False,
        device="cuda",
    ):
        super(M2TDeformer, self).__init__()
        assert output_type in [
            "coord",
            "phi_grad",
            "phi",
        ], f"output type {output_type} is invalid"
        self.device = device
        self.output_type = output_type
        if self.output_type == "coord" or self.output_type == "phi_grad":
            self.output_dim = 2
        elif output_type == "phi":
            self.output_dim = 1
        else:
            raise Exception(f"Output type {output_type} is invalid.")
        lfe_in_c = local_feature_dim_in
        self.lfe_out_c = 16
        self.lfe = LocalFeatExtractor(num_feat=lfe_in_c, out=self.lfe_out_c)
        self.gat_deformer = NetGATDeform(in_dim=feature_in_dim + self.lfe_out_c)
        # self.gat_deformer = NetGATDeform(in_dim=feature_in_dim)
        # GAT layer
        self.to_hidden = GATv2Conv(
            in_channels=coord_size + hidden_size,
            out_channels=hidden_size,
            heads=heads,
            concat=concat,
        )
        # output layer
        self.to_output = nn.Sequential(
            nn.Linear(hidden_size, self.output_dim),
        )
        # activation function
        self.activation = nn.SELU()
[docs]
    def forward(
        self, coord, mesh_feat, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh
    ):
        self.bd_mask = bd_mask.squeeze().bool()
        self.poly_mesh = poly_mesh
        # Recurrent GAT
        # print(coord.shape, hidden_state.shape)
        extra_sample_ratio = coord.shape[0] // hidden_state.shape[0]
        # print(coord.shape, hidden_state.shape)
        local_feat = self.lfe(mesh_feat, edge_index)
        in_feat = torch.cat(
            [coord, local_feat, hidden_state.repeat(extra_sample_ratio, 1)], dim=1
        )
        # in_feat = torch.cat([coord, hidden_state.repeat(extra_sample_ratio, 1)], dim=1)
        # in_feat = torch.cat((coord, hidden_state), dim=1)
        # hidden = self.to_hidden(in_feat, edge_index)
        # hidden = self.activation(hidden)
        # output = self.to_output(hidden)
        # print("in feat shape ", in_feat.shape)
        output = self.gat_deformer(in_feat, edge_index, bd_mask, poly_mesh)
        phix = None
        phiy = None
        if self.output_type == "coord":
            output_coord = output
            # find boundary
            self.find_boundary(coord_ori)
            # fix boundary
            self.fix_boundary(output_coord)
        elif self.output_type == "phi_grad":
            output_coord = output + coord_ori
            # find boundary
            self.find_boundary(coord_ori)
            # fix boundary
            self.fix_boundary(output_coord)
            phix = output[:, 0].view(-1, 1)
            phiy = output[:, 1].view(-1, 1)
        elif self.output_type == "phi":
            # Compute the residual to the equation
            grad_seed = torch.ones(output.shape).to(self.device)
            phi_grad = torch.autograd.grad(
                output,
                coord_ori,
                grad_outputs=grad_seed,
                retain_graph=True,
                create_graph=True,
                allow_unused=False,
            )[0]
            # print(f"[phi grad] {phi_grad.shape}, [coord_ori] {coord_ori.shape}")
            phix = phi_grad[:, 0]
            phiy = phi_grad[:, 1]
            # New coord
            coord_x = (coord_ori[:, 0] + phix).reshape(-1, 1)
            coord_y = (coord_ori[:, 1] + phiy).reshape(-1, 1)
            output_coord = torch.cat([coord_x, coord_y], dim=-1).reshape(-1, 2)
            # find boundary
            self.find_boundary(coord_ori)
            # fix boundary
            self.fix_boundary(output_coord)
            # print('[phi] output coord shape ', output_coord.shape)
        return (output_coord, output), (phix, phiy)
[docs]
    def find_boundary(self, in_data):
        self.upper_node_idx = in_data[:, 0] == 1
        self.down_node_idx = in_data[:, 0] == 0
        self.left_node_idx = in_data[:, 1] == 0
        self.right_node_idx = in_data[:, 1] == 1
        if self.poly_mesh:
            self.bd_pos_x = in_data[self.bd_mask, 0].clone()
            self.bd_pos_y = in_data[self.bd_mask, 1].clone()
[docs]
    def fix_boundary(self, in_data):
        in_data[self.upper_node_idx, 0] = 1
        in_data[self.down_node_idx, 0] = 0
        in_data[self.left_node_idx, 1] = 0
        in_data[self.right_node_idx, 1] = 1
        if self.poly_mesh:
            in_data[self.bd_mask, 0] = self.bd_pos_x
            in_data[self.bd_mask, 1] = self.bd_pos_y