# Source code for UM2N.model.M2T_deformer

import os
import sys

import torch
import torch.nn as nn
from extractor import LocalFeatExtractor
from M2N import NetGATDeform
from torch_geometric.nn import GATv2Conv, MessagePassing

# Make this file's directory importable (presumably so that the sibling
# modules ``extractor`` and ``M2N`` resolve as top-level modules).
# NOTE(review): this runs *after* the ``from extractor ...`` / ``from M2N ...``
# imports at the top of the file, so it cannot help resolve those; they only
# succeed if the directory is already on ``sys.path``. Consider moving these
# lines above those imports (with ``# noqa: E402``) or using relative imports.
# ``abspath`` keeps the entry valid even if ``__file__`` is relative and the
# working directory later changes.
cur_dir = os.path.dirname(os.path.abspath(__file__))
if cur_dir not in sys.path:
    # Avoid piling up duplicate entries on every re-import / reload.
    sys.path.append(cur_dir)

# Public API of this module: ``from ... import *`` exports only the
# deformer class.
__all__ = ["M2TDeformer"]


class M2TDeformer(MessagePassing):
    """Graph-attention mesh deformer used by the M2T model.

    ``forward`` builds a per-node input from the current coordinates, local
    features extracted from ``mesh_feat`` and a (tiled) hidden state, and runs
    it through ``NetGATDeform``. The raw output is then turned into new node
    coordinates according to ``output_type``. Boundary nodes are pinned to
    their original positions.

    Attributes:
        lfe (LocalFeatExtractor): Local feature extractor with ``lfe_out_c``
            (16) output channels.
        gat_deformer (NetGATDeform): Network that produces the raw output.
        to_hidden (GATv2Conv): Graph attention layer. Not used by ``forward``.
        to_output (nn.Sequential): Linear head of width ``output_dim``. Not
            used by ``forward``.
        activation (nn.SELU): Activation function. Not used by ``forward``.
        output_dim (int): 2 for ``"coord"`` / ``"phi_grad"``, 1 for ``"phi"``.
        output_type (str): One of ``"coord"``, ``"phi_grad"``, ``"phi"``.
    """

    def __init__(
        self,
        feature_in_dim,
        local_feature_dim_in,
        coord_size=2,
        hidden_size=512,
        heads=6,
        output_type="coord",
        concat=False,
        device="cuda",
    ):
        """Build the deformer.

        Args:
            feature_in_dim: Width of the ``NetGATDeform`` input, excluding the
                ``lfe_out_c`` local-feature channels added on top. Presumably
                this equals the combined width of ``coord`` and
                ``hidden_state`` given to ``forward``; TODO confirm.
            local_feature_dim_in: Passed to ``LocalFeatExtractor`` as
                ``num_feat``. Presumably the channel count of ``mesh_feat``.
            coord_size: Coordinate width. Only used to size ``to_hidden``.
            hidden_size: Hidden width of ``to_hidden`` / ``to_output``.
            heads: Number of attention heads of ``to_hidden``.
            output_type: ``"coord"``, ``"phi_grad"`` or ``"phi"``.
            concat: Forwarded to ``GATv2Conv`` for ``to_hidden``.
            device: Stored as ``self.device``.

        Raises:
            AssertionError: If ``output_type`` is not supported. When asserts
                are stripped (``python -O``), a plain ``Exception`` is raised
                instead.
        """
        super().__init__()
        assert output_type in [
            "coord",
            "phi_grad",
            "phi",
        ], f"output type {output_type} is invalid"
        self.device = device
        self.output_type = output_type
        if output_type in ("coord", "phi_grad"):
            self.output_dim = 2
        elif output_type == "phi":
            self.output_dim = 1
        else:
            # Only reachable when asserts are disabled.
            raise Exception(f"Output type {output_type} is invalid.")

        self.lfe_out_c = 16
        self.lfe = LocalFeatExtractor(
            num_feat=local_feature_dim_in, out=self.lfe_out_c
        )
        # forward() appends lfe_out_c local-feature channels to the input.
        self.gat_deformer = NetGATDeform(in_dim=feature_in_dim + self.lfe_out_c)

        # The layers below are not used by forward(). They are still
        # registered as submodules, so to_hidden/to_output contribute to
        # state_dict(). Removing them would change the state-dict keys and may
        # break loading existing checkpoints.
        self.to_hidden = GATv2Conv(
            in_channels=coord_size + hidden_size,
            out_channels=hidden_size,
            heads=heads,
            concat=concat,
        )
        self.to_output = nn.Sequential(
            nn.Linear(hidden_size, self.output_dim),
        )
        self.activation = nn.SELU()

    def forward(
        self, coord, mesh_feat, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh
    ):
        """Deform the mesh.

        Args:
            coord: Current node coordinates. Concatenated along dim 1 with the
                local features and the tiled hidden state.
            mesh_feat: Node features passed to the local feature extractor.
            hidden_state: Hidden state. Tiled ``coord.shape[0] //
                hidden_state.shape[0]`` times along dim 0. This assumes
                ``coord`` has an integer multiple of its rows; TODO confirm.
            edge_index: Graph connectivity, passed to ``lfe`` and
                ``gat_deformer``.
            coord_ori: Original node coordinates. Used to locate boundary
                nodes and, for ``"phi_grad"`` / ``"phi"``, as the base of the
                output coordinates. For ``"phi"`` the network output must
                depend on it, because ``autograd.grad`` uses
                ``allow_unused=False``.
            bd_mask: Boundary mask. Passed unchanged to ``gat_deformer``, and
                stored squeezed and cast to bool for the boundary fix.
            poly_mesh: If truthy, nodes selected by ``bd_mask`` are reset to
                their ``coord_ori`` positions.

        Returns:
            ``((output_coord, output), (phix, phiy))``, where ``output`` is the
            raw ``gat_deformer`` result and the rest depends on
            ``output_type``:

            * ``"coord"``: ``output_coord`` *is* ``output``, so the in-place
              boundary fix also modifies ``output``. ``phix`` and ``phiy``
              are ``None``.
            * ``"phi_grad"``: ``output_coord = output + coord_ori``.
              ``phix`` and ``phiy`` are ``output[:, 0]`` and ``output[:, 1]``
              viewed as ``(-1, 1)``.
            * ``"phi"``: ``phix`` and ``phiy`` are the columns (1-D) of
              d(output)/d(coord_ori), and ``output_coord`` is ``coord_ori``
              shifted by them.
        """
        # Stored for find_boundary / fix_boundary.
        self.bd_mask = bd_mask.squeeze().bool()
        self.poly_mesh = poly_mesh

        # Tile the hidden state so its row count lines up with coord.
        extra_sample_ratio = coord.shape[0] // hidden_state.shape[0]
        local_feat = self.lfe(mesh_feat, edge_index)
        in_feat = torch.cat(
            [coord, local_feat, hidden_state.repeat(extra_sample_ratio, 1)], dim=1
        )
        output = self.gat_deformer(in_feat, edge_index, bd_mask, poly_mesh)

        phix = None
        phiy = None
        if self.output_type == "coord":
            # Aliases output on purpose (original behavior).
            output_coord = output
        elif self.output_type == "phi_grad":
            output_coord = output + coord_ori
            phix = output[:, 0].view(-1, 1)
            phiy = output[:, 1].view(-1, 1)
        elif self.output_type == "phi":
            # create_graph=True keeps the gradient differentiable, so losses
            # on it can be back-propagated. ones_like follows output's device
            # and dtype. A seed built on self.device would fail when the model
            # runs elsewhere.
            phi_grad = torch.autograd.grad(
                output,
                coord_ori,
                grad_outputs=torch.ones_like(output),
                retain_graph=True,
                create_graph=True,
                allow_unused=False,
            )[0]
            phix = phi_grad[:, 0]
            phiy = phi_grad[:, 1]
            output_coord = coord_ori[:, :2] + phi_grad[:, :2]

        # Boundary nodes are located on the original mesh and pinned in the
        # new coordinates (in place).
        self.find_boundary(coord_ori)
        self.fix_boundary(output_coord)

        return (output_coord, output), (phix, phiy)

    def find_boundary(self, in_data):
        """Record boundary-node masks (and positions) of ``in_data``.

        Sets four boolean masks using exact float equality:

        * ``upper_node_idx``: column 0 equals 1.
        * ``down_node_idx``: column 0 equals 0.
        * ``left_node_idx``: column 1 equals 0.
        * ``right_node_idx``: column 1 equals 1.

        If ``self.poly_mesh`` is truthy, it also stores copies of columns 0
        and 1 for the nodes selected by ``self.bd_mask``. Both attributes are
        set by ``forward``, which must therefore run first.

        Args:
            in_data: Node coordinates with at least two columns.
        """
        x = in_data[:, 0]
        y = in_data[:, 1]
        self.upper_node_idx = x == 1
        self.down_node_idx = x == 0
        self.left_node_idx = y == 0
        self.right_node_idx = y == 1
        if self.poly_mesh:
            self.bd_pos_x = in_data[self.bd_mask, 0].clone()
            self.bd_pos_y = in_data[self.bd_mask, 1].clone()

    def fix_boundary(self, in_data):
        """Overwrite, in place, the boundary nodes recorded by ``find_boundary``.

        Nodes in the column-0 masks get column 0 reset to 1 / 0. Nodes in the
        column-1 masks get column 1 reset to 0 / 1. If ``self.poly_mesh`` is
        truthy, nodes selected by ``self.bd_mask`` get the positions saved by
        ``find_boundary``.

        Args:
            in_data: Coordinates to modify in place. Nothing is returned.
        """
        in_data[self.upper_node_idx, 0] = 1
        in_data[self.down_node_idx, 0] = 0
        in_data[self.left_node_idx, 1] = 0
        in_data[self.right_node_idx, 1] = 1
        if self.poly_mesh:
            in_data[self.bd_mask, 0] = self.bd_pos_x
            in_data[self.bd_mask, 1] = self.bd_pos_y