# This file was not written by the author of the project.
# The purpose of this file is comparison with the MRN model.
# The implemented DeformGAT class is from the M2N paper:
# https://arxiv.org/abs/2204.11188
# The original code is from: https://github.com/erizmr/M2N. However,
# this is a private repo belonging to https://github.com/erizmr, so the
# marker of this project may need to contact the original author for the
# original code base.
import os
import sys
import torch
# NOTE(review): F is not used in the visible part of this file.
import torch.nn.functional as F
# Put this file's own directory on sys.path so that the sibling modules
# `extractor` and `gatdeformer` resolve as top-level imports no matter what
# the current working directory is. The imports below must come after this
# statement, hence the flake8 E402 ("import not at top of file") waivers.
cur_dir = os.path.dirname(__file__)
sys.path.append(cur_dir)
from extractor import GlobalFeatExtractor, LocalFeatExtractor  # noqa: E402
from gatdeformer import DeformGAT  # noqa: E402
# Only the M2N model is exported by `from <module> import *`.
__all__ = ["M2N"]
class M2N(torch.nn.Module):
    """M2N mesh-movement model, reproduced for comparison with MRN.

    The forward pass builds one feature row per node by concatenating
    ``data.x``, a local feature from ``LocalFeatExtractor`` and a global
    feature from ``GlobalFeatExtractor`` (repeated for every node of its
    sample), then passes the result to ``self.deformer``.

    Args:
        gfe_in_c: ``in_c`` passed to ``GlobalFeatExtractor``.
        lfe_in_c: ``num_feat`` passed to ``LocalFeatExtractor``.
        deform_in_c: Number of raw per-node input features. Presumably this
            must equal ``data.x.shape[1]``, because the deformer input width
            is ``deform_in_c + 16 + 16`` (assumes both extractors emit 16
            features each -- confirm against extractor.py).
        use_drop: ``use_drop`` flag passed to ``GlobalFeatExtractor``.
    """

    def __init__(self, gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False):
        super().__init__()
        self.gfe_out_c = 16
        self.lfe_out_c = 16
        # Width of the concatenation [x, local_feat, conv_feat] built in
        # forward().
        self.deformer_in_feat = deform_in_c + self.gfe_out_c + self.lfe_out_c
        self.gfe = GlobalFeatExtractor(
            in_c=gfe_in_c, out_c=self.gfe_out_c, use_drop=use_drop
        )
        self.lfe = LocalFeatExtractor(num_feat=lfe_in_c, out=self.lfe_out_c)
        # NOTE(review): NetGATDeform is neither imported here (only DeformGAT
        # is) nor defined in the visible part of this file -- confirm it is
        # defined elsewhere in this module.
        self.deformer = NetGATDeform(in_dim=self.deformer_in_feat)

    def forward(self, data, poly_mesh=False):
        """Run the model on a (possibly batched) mesh graph.

        Args:
            data: Graph object. The attributes read are ``bd_mask``,
                ``poly_mesh``, ``x``, ``conv_feat``, ``mesh_feat``,
                ``edge_index`` and ``node_num``.
            poly_mesh: Fallback flag forwarded to the deformer. It is
                overridden by ``data.poly_mesh`` unless that attribute is
                ``False`` or ``None``; then it is True iff the attribute's
                sum is positive.

        Returns:
            The output of ``self.deformer`` applied to the concatenated
            node features.
        """
        bd_mask = data.bd_mask
        data_poly_mesh = data.poly_mesh
        # ``None`` means "not provided" (some graph containers return None
        # for unset attributes); calling ``.sum()`` on it would crash, so
        # keep the argument value in that case.
        if data_poly_mesh is not False and data_poly_mesh is not None:
            poly_mesh = bool(data_poly_mesh.sum() > 0)
        x = data.x
        conv_feat_in = data.conv_feat
        mesh_feat = data.mesh_feat
        # NOTE(review): presumably [2, num_edges] (PyG convention); an older
        # comment here claimed [num_edges * batch_size, 2] -- confirm.
        edge_idx = data.edge_index
        node_num = data.node_num

        conv_feat = self.gfe(conv_feat_in)
        # One global row per entry of node_num, repeated node_num[i] times so
        # every node receives the global feature of the sample it belongs to.
        conv_feat = conv_feat.repeat_interleave(node_num.reshape(-1), dim=0)
        local_feat = self.lfe(mesh_feat, edge_idx)
        x = torch.cat([x, local_feat, conv_feat], dim=1)
        return self.deformer(x, edge_idx, bd_mask, poly_mesh)