UM2N.model package

Submodules

UM2N.model.M2N module

class NetGATDeform(in_dim)[source]

Bases: Module

forward(data, edge_idx, bd_mask, poly_mesh)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class M2N(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]

Bases: Module

forward(data, poly_mesh=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
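A minimal usage sketch (not part of the package's documented examples), assuming M2N is importable from its submodule and that data is a torch_geometric Data batch prepared by UM2N's data utilities:

    import torch
    from UM2N.model.M2N import M2N

    model = M2N(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)
    model.eval()

    # Placeholder: a torch_geometric.data.Data batch with the fields M2N
    # expects (mesh coordinates, monitor features, etc.).
    data = ...

    # Call the module instance rather than model.forward(data) so that
    # registered hooks run, as the note above describes.
    with torch.no_grad():
        out_coord = model(data)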

UM2N.model.M2N_T module

class NetGATDeform(in_dim)[source]

Bases: Module

forward(data, edge_idx, bd_mask, poly_mesh)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class M2N_T(gfe_in_c=3, lfe_in_c=3, deform_in_c=3)[source]

Bases: Module

forward(data, poly_mesh=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

UM2N.model.M2N_atten module

class NetGATDeform(in_dim)[source]

Bases: Module

forward(data, edge_idx)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class M2NAtten(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]

Bases: Module

forward(data)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

UM2N.model.M2N_dynamic_drop module

class NetGATDeform(in_dim)[source]

Bases: Module

forward(data, edge_idx)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class M2N_dynamic_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]

Bases: Module

forward(data)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

UM2N.model.M2N_dynamic_no_drop module

class NetGATDeform(in_dim)[source]

Bases: Module

forward(data, edge_idx)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class M2N_dynamic_no_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]

Bases: Module

forward(data)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

UM2N.model.M2T module

class M2T(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', local_feature_dim_in=4, num_loop=3, device='cuda')[source]

Bases: Module

Mesh Refinement Network (MRN) implementing a transformer as feature extractor and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

transformer_monitor(data, input_q, input_kv, boundary)[source]
forward(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

get_attention_scores(data)[source]

UM2N.model.M2T_deformer module

class M2TDeformer(feature_in_dim, local_feature_dim_in, coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]

Bases: MessagePassing

Implements an M2TDeformer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(coord, mesh_feat, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]

UM2N.model.MRN module

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(coord, hidden_state, edge_index, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
class MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.
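A hypothetical inference sketch for MRN; data stands for a torch_geometric Data batch prepared by UM2N's data utilities and is not constructed here:

    import torch
    from UM2N.model.MRN import MRN

    model = MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
    model.eval()

    data = ...  # placeholder: a torch_geometric.data.Data batch

    with torch.no_grad():
        coord = model(data)                       # full forward: num_loop deformation steps
        coord_one = model.move(data, num_step=1)  # or apply a single deformation step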

UM2N.model.MRN_GTE module

class MRNGlobalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_LTE module

class MRNLocalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The local feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_atten module

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(coord, hidden_state, edge_index)[source]

Runs the forward pass of the module.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
class MRNAtten(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) with self-attention, implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_fix module

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(coord, hidden_state, edge_index)[source]

Runs the forward pass of the module.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
class MRN_fix(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_phi module

class RecurrentGATConv(phi_size=1, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(phi, hidden_state, edge_index)[source]

Runs the forward pass of the module.

class MRN_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Recurrent Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations for field phi.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRT module

class MRTransformer(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')[source]

Bases: Module

Mesh Refinement Network (MRN) implementing a transformer as feature extractor and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

transformer_monitor(data, input_q, input_kv, boundary)[source]
move(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, num_step=1, poly_mesh=False)[source]
Move the mesh according to the learned deformation, for a given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

forward(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

get_attention_scores(data)[source]

UM2N.model.MRT_PE module

UM2N.model.MRT_phi module

class RecurrentGATConv(in_size=1, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(in_feat, hidden_state, edge_index)[source]

Runs the forward pass of the module.

class MRT_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.deformer module

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

forward(coord, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]

UM2N.model.extractor module

class LocalFeatExtractor(num_feat=10, out=16)[source]

Bases: MessagePassing

Custom PyTorch Geometric layer that performs feature extraction on the local graph structure.

The class extends the torch_geometric.nn.MessagePassing class and employs additive aggregation as the message-passing scheme.

Attributes:

lin_1 (torch.nn.Linear): First linear layer.
lin_2 (torch.nn.Linear): Second linear layer.
lin_3 (torch.nn.Linear): Third linear layer.
activate (torch.nn.SELU): Activation function.

forward(input, edge_index)[source]

Forward pass.

Args:

input (Tensor): Node features.
edge_index (Tensor): Edge indices.

Returns:

Tensor: Updated node features.

message(x_i, x_j)[source]

Constructs messages from node \(j\) to node \(i\), in analogy to \(\phi_{\mathbf{\Theta}}\), for each edge in edge_index. This function can take any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
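To illustrate the _i/_j convention described above, here is a self-contained toy message-passing layer with additive aggregation; it is not LocalFeatExtractor itself, just a sketch of the mechanism:

    import torch
    from torch_geometric.nn import MessagePassing

    class ToyLayer(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")  # additive aggregation, as in LocalFeatExtractor

        def forward(self, x, edge_index):
            # Tensors passed to propagate() become x_i / x_j inside message().
            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            # x_i: features of target node i, x_j: features of source node j.
            return x_j - x_i  # e.g. local feature differences along each edge

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    out = ToyLayer()(x, edge_index)
    print(out.shape)  # torch.Size([4, 16])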

class GlobalFeatExtractor(in_c, out_c, drop_p=0.2, use_drop=True)[source]

Bases: Module

Custom PyTorch layer for global feature extraction.

The class employs multiple convolutional layers and dropout layers.

Attributes:

conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
dropout (torch.nn.Dropout): Dropout layer.
final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.

forward(data)[source]

Forward pass.

Args:

data (Tensor): Input data.

Returns:

Tensor: Extracted global features.
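A shape-only sketch of using GlobalFeatExtractor; the [batch, in_c, H, W] input layout is an assumption based on the Conv2d and AdaptiveAvgPool2d attributes listed above:

    import torch
    from UM2N.model.extractor import GlobalFeatExtractor

    gfe = GlobalFeatExtractor(in_c=2, out_c=16, use_drop=False)
    feat = gfe(torch.randn(8, 2, 20, 20))  # assumed image-like input [batch, in_c, H, W]
    print(feat.shape)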

class TransformerEncoder(num_transformer_in=3, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')[source]

Bases: Module

Mesh Refinement Network (MRN) implementing a transformer as feature extractor and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

forward(data)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

UM2N.model.gatdeformer module

class DeformGAT(in_channels: int, out_channels: int, heads: int = 1, concat: bool = False, negative_slope: float = 0.2, dropout: float = 0, bias: bool = False, **kwargs)[source]

Bases: MessagePassing

reset_parameters()[source]

Resets all learnable parameters of the module.

forward(coords: Tensor | Tuple[Tensor, Tensor | None], features: Tensor | Tuple[Tensor, Tensor | None], edge_index: Tensor | SparseTensor, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

message(x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) → Tensor[source]

Constructs messages from node \(j\) to node \(i\), in analogy to \(\phi_{\mathbf{\Theta}}\), for each edge in edge_index. This function can take any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]

UM2N.model.train_util module

chamfer_distance(*args)[source]
get_face_area(coord, face)[source]
Calculates the area of a face using the formula:

area = 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))

Args:

coord (torch.Tensor): The coordinates.
face (torch.Tensor): The face tensor.
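A self-contained check of the formula above on a unit right triangle; the [3, num_faces] layout of face is assumed purely for illustration:

    import torch

    coord = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # 3 nodes
    face = torch.tensor([[0], [1], [2]])  # hypothetical [3, num_faces] layout

    x1, y1 = coord[face[0], 0], coord[face[0], 1]
    x2, y2 = coord[face[1], 0], coord[face[1], 1]
    x3, y3 = coord[face[2], 0], coord[face[2], 1]
    area = 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))
    print(area)  # tensor([0.5000]): a unit right triangle has area 0.5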

get_inversion_loss(out_coord, in_coord, face, batch_size, scheme='relu', scaler=100)[source]

Calculates the inversion loss for a batch of meshes.

Args:

out_coord (torch.Tensor): The output coordinates.
in_coord (torch.Tensor): The input coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.
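A minimal sketch of a 'relu'-scheme inversion penalty on signed element areas, given only to illustrate the idea (not the package's exact implementation):

    import torch

    def inversion_penalty(signed_area: torch.Tensor, scaler: float = 100.0) -> torch.Tensor:
        # relu(-area) is zero for positively oriented elements and grows
        # linearly with the magnitude of inversion (negative signed area).
        return scaler * torch.relu(-signed_area).mean()

    areas = torch.tensor([0.50, 0.30, -0.01])  # one inverted element
    print(inversion_penalty(areas))  # tensor(0.3333)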

get_inversion_diff_loss(out_coord, tar_coord, face, batch_size, scaler=100)[source]

Calculates the inversion difference loss for a batch of meshes, i.e. the difference between the output area and the input area over the inverted elements.

Args:

out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.

get_inversion_node_loss(out_coord, tar_coord, face, batch_size, scaler=1000)[source]

Calculates the loss between the output nodes and the input nodes for the inverted elements. This penalises nodes that are involved in tangled elements.

Args:

out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.

get_area_loss(out_coord, tar_coord, face, batch_size, scaler=100)[source]
jacobLoss(model, out, data, loss_func)[source]
get_jacob_det(model, in_data)[source]
class TangleCounter(num_feat=10, out=16)[source]

Bases: MessagePassing

A PyTorch Geometric message-passing class for counting tangles in the mesh. This class is deprecated; do not use it unless you know what you are doing.

forward(x, x_new, edge_index)[source]

Runs the forward pass of the module.

message(x_i, x_j, x_new_i, x_new_j)[source]

Constructs messages from node \(j\) to node \(i\), in analogy to \(\phi_{\mathbf{\Theta}}\), for each edge in edge_index. This function can take any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

print_parameter_grad(model)[source]
train(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]
Trains a PyTorch model using the given data loader, optimizer, and loss function.

Args:

loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.

Returns:

float: The average training loss across all batches.
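Hypothetical wiring for a training loop around train(...); the learning rate, loss function, and loader construction are illustrative placeholders, not documented defaults:

    import torch
    from UM2N.model.M2N import M2N
    from UM2N.model.train_util import train

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = M2N().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative lr
    loss_func = torch.nn.L1Loss()                              # illustrative loss

    # Placeholder: a torch_geometric DataLoader over a UM2N mesh dataset.
    loader = ...

    for epoch in range(100):
        avg_loss = train(loader, model, optimizer, device, loss_func)
        print(f"epoch {epoch:03d}: train loss {avg_loss:.4e}")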

evaluate(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]

Evaluates a model using the given data loader and loss function.

Args:

loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.

Returns:

float: The average evaluation loss across all batches.

interpolate(u, ori_mesh_x, ori_mesh_y, moved_x, moved_y)[source]

u: [bs, node_num, 1]
ori_mesh_x: [bs, node_num, 1]
ori_mesh_y: [bs, node_num, 1]
moved_x: [bs, node_num, 1]
moved_y: [bs, node_num, 1]

Note: node_num equals sample_num.
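A shape-only sketch matching the documented tensor layout above; the random coordinates are purely illustrative:

    import torch
    from UM2N.model.train_util import interpolate

    bs, node_num = 2, 100
    u = torch.randn(bs, node_num, 1)       # field values on the original mesh
    ori_x = torch.rand(bs, node_num, 1)    # original mesh coordinates
    ori_y = torch.rand(bs, node_num, 1)
    moved_x = torch.rand(bs, node_num, 1)  # coordinates after mesh movement
    moved_y = torch.rand(bs, node_num, 1)

    u_moved = interpolate(u, ori_x, ori_y, moved_x, moved_y)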

generate_samples(bs, num_samples_per_mesh, data, num_meshes=5, device='cuda')[source]
compute_finite_difference(field)[source]
generate_samples_structured_grid(coords, field, grid_resolution=100, device='cuda')[source]
construct_graph(sampled_coords, num_neighbors=6, device='cuda')[source]
compute_phi_hessian(mesh_query_x, mesh_query_y, phix, phiy, out_monitor, bs, data, loss_func, finite_difference_grad=False)[source]
model_forward(bs, data, model, use_add_random_query=True)[source]
sample_nodes_by_monitor(meshes, meshes_target, monitors, num_samples_per_mesh=100, random_seed=666)[source]
train_unsupervised(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_chamfer_loss=1, weight_eq_residual_loss=1, scaler=100)[source]
Trains a PyTorch model using the given data loader, optimizer, and loss function.

Args:

loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.

Returns:

float: The average training loss across all batches.

evaluate_unsupervised(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_eq_residual_loss=1, weight_chamfer_loss=1, scaler=100)[source]

Evaluates a model using the given data loader and loss function.

Args:

loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.

Returns:

float: The average evaluation loss across all batches.

get_sample_tangle(out_coords, in_coords, face)[source]

Return the number of tangled elements in a single sample.

count_dataset_tangle(dataset, model, device, method='inversion')[source]

Computes the average number of tangles in a dataset.

Args:

dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.

Returns:

float: The average number of tangles in the dataset.

evaluate_repeat_sampling(dataset, model, device, loss_func, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, scaler=100, batch_size=5, num_samples=1)[source]

Evaluates a model using the given data loader and loss function.

Args:

loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.

Returns:

float: The average evaluation loss across all batches.

count_dataset_tangle_repeat_sampling(dataset, model, device, num_samples=1)[source]

Computes the average number of tangles in a dataset.

Args:

dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.

Returns:

float: The average number of tangles in the dataset.

evaluate_repeat(dataset, model, device, loss_func, scaler=100, num_repeat=1)[source]

Evaluates model performance when sampling a different number of times. This function evaluates:

  1. the average loss

  2. the average number of tangles

Args:

dataset (MeshDataset): The target dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).

Returns:

float: The average evaluation loss across all batches.

load_model(model, weight_path, strict=False)[source]

Loads pre-trained weights into a PyTorch model from a given file path.

Args:

model (torch.nn.Module): The PyTorch model.
weight_path (str): File path to the pre-trained model weights.

Returns:

torch.nn.Module: The model with loaded weights.
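Typical checkpoint loading with load_model(...); the weight path below is illustrative:

    from UM2N.model.M2N import M2N
    from UM2N.model.train_util import load_model

    model = M2N()
    # strict=False (the default above) tolerates missing or unexpected keys
    # in the checkpoint's state dict.
    model = load_model(model, "weights/model.pth")  # illustrative path
    model.eval()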

UM2N.model.transformer_model module

class MLP_model(input_channels, output_channels, list_hiddens=[128, 128], hidden_act='LeakyReLU', output_act='LeakyReLU', input_norm=None, dropout_prob=0.0)[source]

Bases: Module

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class TransformerBlock(embed_dim, num_heads, dense_ratio=4, list_dropout=[0.1, 0.1, 0.1], activation='GELU')[source]

Bases: Module

forward(x, k, v, x_cls=None, key_padding_mask=None, attn_mask=None, return_attn=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class TransformerModel(*, input_dim, embed_dim, output_dim, num_heads=4, num_layers=3)[source]

Bases: Module

forward(x, k, v, key_padding_mask=None, attention_mask=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_attention_scores(x, key_padding_mask=None, attn_mask=None)[source]
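A minimal sketch exercising TransformerModel as a self-attention encoder; the shapes are assumptions for illustration, not documented defaults:

    import torch
    from UM2N.model.transformer_model import TransformerModel

    model = TransformerModel(input_dim=4, embed_dim=64, output_dim=16,
                             num_heads=4, num_layers=3)

    x = torch.randn(2, 100, 4)  # assumed [batch, num_nodes, input_dim]
    out = model(x, x, x)        # self-attention: queries, keys, values all x
    print(out.shape)            # expected [2, 100, 16] under these assumptions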

Module contents