# Source code for UM2N.model.gatdeformer
# This file was not written by the author of this project.
# Its purpose is comparison with the MRN model.
# The implemented DeformGAT class is from the M2N paper:
# https://arxiv.org/abs/2204.11188
# The original code is from: https://github.com/erizmr/M2N. However,
# this is a private repo belonging to https://github.com/erizmr, so the
# marker of this project may need to contact the original author for
# access to the original code base.
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot
from torch_geometric.typing import Adj, OptPairTensor, OptTensor
from torch_geometric.utils import softmax
# Public API of this module. (A stray Sphinx "[docs]" marker that followed
# this line evaluated the undefined name `docs` and broke module import.)
__all__ = ["DeformGAT"]
class DeformGAT(MessagePassing):
    """Graph-attention layer that deforms node coordinates and updates features.

    Re-implementation of the DeformGAT layer from the M2N paper
    (https://arxiv.org/abs/2204.11188), kept for comparison with MRN.

    Each forward pass runs two attention-weighted aggregations over
    ``edge_index`` that share the same attention logits: one moves node
    coordinates (with the logits scaled by 0.2), the other mixes projected
    node features. Head outputs are averaged, and boundary nodes are then
    restored to their input positions.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output feature vector (per head).
        heads: Number of attention heads; ``forward`` always averages them.
        concat: Only selects the size of the ``bias`` parameter;
            ``forward`` averages over heads regardless of this flag.
        negative_slope: Stored on the instance but not read by this layer
            (``message`` applies SELU, not LeakyReLU).
        dropout: Stored on the instance but not applied anywhere.
        bias: If True, allocate a zero-initialised ``bias`` parameter. Note
            that ``forward`` does NOT add it to the output.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to "add".
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int = 1,
        concat: bool = False,
        negative_slope: float = 0.2,
        dropout: float = 0,
        bias: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        # node_dim=0: node tensors are laid out as [num_nodes, heads, ...].
        super().__init__(node_dim=0, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        # This used to be overwritten with -0.2 just before reset_parameters(),
        # silently discarding the argument. Neither value is read anywhere in
        # this layer, so keep what the caller passed.
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = False

        # Shared projection for source and target nodes.
        # TODO (original author): decide whether this linear bias should be
        # True or False; it was considered that "no bias" would not work well.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=True).float()
        # Alias of the same module, kept for backward compatibility. Being
        # registered under both names, state_dicts carry both "lin_l.*" and
        # "lin_.*" keys.
        self.lin_ = self.lin_l

        # Per-head attention vectors. With PyG's default source_to_target
        # flow, att_l scores the source side (alpha_j) and att_r the target
        # side (alpha_i) in message().
        self.att_l = Parameter(torch.empty((1, heads, out_channels), dtype=torch.float32))
        self.att_r = Parameter(torch.empty((1, heads, out_channels), dtype=torch.float32))

        if bias:
            # Sized for a concatenated ([heads * out_channels]) or averaged
            # ([out_channels]) output, depending on `concat`.
            bias_size = heads * out_channels if concat else out_channels
            self.bias = Parameter(torch.empty(bias_size, dtype=torch.float32))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise the projection weight, attention vectors and bias."""
        glorot(self.lin_l.weight)
        # lin_ is the same module as lin_l, so this re-draws the same weight.
        # It is redundant, but kept so that seeded runs reproduce the original
        # initialisation (each glorot call consumes random numbers).
        glorot(self.lin_.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        if self.bias is not None:
            # Previously left as uninitialised memory. zeros_ draws no random
            # numbers, so the RNG stream for the weights above is unchanged.
            torch.nn.init.zeros_(self.bias)

    def forward(
        self,
        coords: Union[Tensor, OptPairTensor],
        features: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        bd_mask,
        poly_mesh,
    ):
        """Deform node coordinates and update node features.

        Args:
            coords: Node coordinates as a single tensor; the pair form is not
                handled (``unsqueeze`` and column indexing are applied
                directly). Columns 0 and 1 are read, so it needs at least two
                columns. Per the original author, these are the first two
                feature dimensions (TODO confirm against callers).
            features: Node features. They are projected by ``lin_l`` and
                viewed as [N, heads, out_channels].
            edge_index: Graph connectivity passed to ``propagate``.
            bd_mask: Mask of boundary nodes. It is squeezed and cast to bool,
                so a trailing singleton dimension is accepted.
            poly_mesh: Stored on ``self.poly_mesh``; otherwise unused.

        Returns:
            ``(out_coords, out_features)``:
            - ``out_coords``: head-averaged coordinates, with boundary nodes
              restored.
            - ``out_features``: SELU-activated, head-averaged features of
              shape [N, out_channels].
        """
        self.bd_mask = bd_mask.squeeze().bool()
        self.poly_mesh = poly_mesh
        # Record boundary masks/positions from the *input* coordinates so that
        # fix_boundary() can restore them on the output.
        self.find_boundary(coords)

        H, C = self.heads, self.out_channels
        x_l = x_r = self.lin_l(features).view(-1, H, C)  # [N, H, C]
        alpha_l = (x_l * self.att_l).sum(dim=-1)  # [N, H]
        alpha_r = (x_r * self.att_r).sum(dim=-1)  # [N, H]

        # Add a singleton head dim so coordinates broadcast against [E, H, 1]
        # attention weights in message().
        x_coords = coords.unsqueeze(1)  # [N, 1, D]
        # The coordinate branch uses logits scaled by 0.2. This value comes
        # from the original implementation.
        # TODO (original author): why are alpha_l / alpha_r scaled by 0.2 here?
        out_coords = self.propagate(
            edge_index, x=(x_coords, x_coords), alpha=(0.2 * alpha_l, 0.2 * alpha_r)
        )  # [N, H, D]
        out_coords = out_coords.mean(dim=1)  # [N, D]

        out_features = self.propagate(
            edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r)
        )  # [N, H, C]
        out_features = out_features.mean(dim=1)  # [N, C]
        # TODO (original author): can this final activation be removed?
        out_features = F.selu(out_features)

        # Restore boundary nodes; this modifies out_coords in place.
        self.fix_boundary(out_coords)
        return out_coords, out_features

    def message(
        self,
        x_j: Tensor,
        alpha_j: Tensor,
        alpha_i: OptTensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
    ) -> Tensor:
        """Weight each source-node message by its normalised attention score.

        The logits are ``SELU(alpha_j + alpha_i)``, or ``SELU(alpha_j)`` when
        ``alpha_i`` is None. They are softmax-normalised over the edges that
        share the same target node (``index``).
        """
        # forward() always passes both terms, so the sum branch is the one
        # normally taken.
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        # NOTE: SELU here, not the LeakyReLU(negative_slope) of standard GAT.
        alpha = F.selu(alpha)
        # Group the softmax by target node (`index`).
        alpha = softmax(alpha, index, ptr, size_i)
        # [E, H] -> [E, H, 1] so it broadcasts over the last dim of x_j.
        return x_j * alpha.unsqueeze(-1)

    def find_boundary(self, in_data):
        """Cache boundary masks and positions of ``in_data``.

        Stores masks of nodes whose column 0 equals exactly 1 or 0, or whose
        column 1 equals exactly 0 or 1. These are presumably the edges of a
        unit-square domain (TODO confirm). It also stores copies of both
        coordinates of the nodes selected by ``self.bd_mask``, which must be
        set beforehand, as ``forward`` does.
        """
        self.upper_node_idx = in_data[:, 0] == 1
        self.down_node_idx = in_data[:, 0] == 0
        self.left_node_idx = in_data[:, 1] == 0
        self.right_node_idx = in_data[:, 1] == 1
        # Applied unconditionally; an `if self.poly_mesh:` guard was commented
        # out in the original.
        self.bd_pos_x = in_data[self.bd_mask, 0].clone()
        self.bd_pos_y = in_data[self.bd_mask, 1].clone()

    def fix_boundary(self, in_data):
        """Restore boundary coordinates in ``in_data``, in place.

        Uses the masks and positions cached by ``find_boundary``. The
        ``bd_mask`` restoration runs last, so for nodes matched by both it
        overrides the 0/1 assignments.
        """
        in_data[self.upper_node_idx, 0] = 1
        in_data[self.down_node_idx, 0] = 0
        in_data[self.left_node_idx, 1] = 0
        in_data[self.right_node_idx, 1] = 1
        # Applied unconditionally; an `if self.poly_mesh:` guard was commented
        # out in the original.
        in_data[self.bd_mask, 0] = self.bd_pos_x
        in_data[self.bd_mask, 1] = self.bd_pos_y

    def __repr__(self):
        return "{}({}, {}, heads={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels, self.heads
        )