UM2N.model package

Submodules

UM2N.model.M2N module

class M2N(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]

Bases: Module

Initialize the model.

forward(data, poly_mesh=False)[source]

Run the forward pass.

Note

Call the module instance (for example, model(data)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

UM2N.model.M2N_T module

class M2N_T(gfe_in_c=3, lfe_in_c=3, deform_in_c=3)[source]

Bases: Module

Initialize the model.

forward(data, poly_mesh=False)[source]

Run the forward pass.

Note

Call the module instance (for example, model(data)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

UM2N.model.M2N_atten module

class M2NAtten(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]

Bases: Module

Initialize the model.

forward(data)[source]

Run the forward pass.

Note

Call the module instance (for example, model(data)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

UM2N.model.M2N_dynamic_drop module

class M2N_dynamic_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]

Bases: Module

Initialize the model.

forward(data)[source]

Run the forward pass.

Note

Call the module instance (for example, model(data)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

UM2N.model.M2N_dynamic_no_drop module

class M2N_dynamic_no_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]

Bases: Module

Initialize the model.

forward(data)[source]

Run the forward pass.

Note

Call the module instance (for example, model(data)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

UM2N.model.M2T module

class M2T(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', local_feature_dim_in=4, num_loop=3, device='cuda')[source]

Bases: Module

Mesh Refinement Network (MRN) that uses a transformer as the feature extractor, together with recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
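When transformer_training_mask is enabled, the two ratio bounds in the signature presumably control what fraction of the transformer inputs is masked during training. The sketch below assumes the ratio is drawn uniformly from [transformer_training_mask_ratio_lower_bound, transformer_training_mask_ratio_upper_bound] and that the masked positions are chosen at random; sample_training_mask is a hypothetical helper, not part of UM2N. It follows PyTorch's key_padding_mask convention, in which True marks a position to ignore.

```python
import random
from typing import List, Optional, Tuple


def sample_training_mask(num_tokens: int, lower: float = 0.5, upper: float = 0.9,
                         rng: Optional[random.Random] = None) -> Tuple[List[bool], float]:
    """Hypothetical helper: return (mask, ratio), where mask[i] is True if token i is masked.

    Assumption: the masking ratio is drawn uniformly from [lower, upper].
    """
    if not 0.0 <= lower <= upper <= 1.0:
        raise ValueError("expected 0 <= lower <= upper <= 1")
    rng = rng or random.Random()
    ratio = rng.uniform(lower, upper)
    num_masked = int(ratio * num_tokens)  # round down to a whole number of tokens
    masked = set(rng.sample(range(num_tokens), num_masked))
    return [i in masked for i in range(num_tokens)], ratio


mask, ratio = sample_training_mask(100, rng=random.Random(0))
print(f"ratio={ratio:.3f}, masked tokens={sum(mask)} of {len(mask)}")
```

Seeding the random generator, as in the usage line, makes a given mask reproducible when debugging.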

forward(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

get_attention_scores(data)[source]
transformer_monitor(data, input_q, input_kv, boundary)[source]

UM2N.model.M2T_deformer module

class M2TDeformer(feature_in_dim, local_feature_dim_in, coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]

Bases: MessagePassing

Implements an M2TDeformer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

Initialize the layer.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
forward(coord, mesh_feat, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

UM2N.model.MRN module

class MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.
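move applies the learned deformation num_step times. The sketch below shows that pattern with a plain-Python stand-in for the network: each step takes the previous step's output coordinates as its input. The deform_step callable, the move_mesh name and the list-of-tuples coordinate format are assumptions for illustration; in UM2N each step is the model's forward pass on a Data object.

```python
from typing import Callable, List, Tuple

Coord = Tuple[float, float]


def move_mesh(coords: List[Coord], deform_step: Callable[[List[Coord]], List[Coord]],
              num_step: int = 1) -> List[Coord]:
    """Apply one deformation step num_step times, feeding each output back in as the next input."""
    if num_step < 1:
        raise ValueError("num_step must be at least 1")
    for _ in range(num_step):
        coords = deform_step(coords)
    return coords


def shrink_towards_centre(coords: List[Coord]) -> List[Coord]:
    """Hypothetical deformation: move every node 10% of the way towards (0.5, 0.5)."""
    return [(x + 0.1 * (0.5 - x), y + 0.1 * (0.5 - y)) for x, y in coords]


print(move_mesh([(0.0, 0.0), (1.0, 1.0)], shrink_towards_centre, num_step=2))
```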

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

Initialize the layer.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
forward(coord, hidden_state, edge_index, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.
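The find_boundary and fix_boundary methods and the bd_mask argument suggest that boundary nodes are constrained during deformation. A minimal sketch, assuming the domain is the unit square [0, 1] x [0, 1] and that boundary nodes are simply reset to their original positions; the actual UM2N logic may differ.

```python
from typing import List, Tuple

Coord = Tuple[float, float]


def find_boundary(coords: List[Coord], tol: float = 1e-9) -> List[bool]:
    """Mark nodes lying on the boundary of the unit square (an assumed domain)."""
    return [min(abs(x), abs(x - 1.0), abs(y), abs(y - 1.0)) < tol for x, y in coords]


def fix_boundary(new_coords: List[Coord], orig_coords: List[Coord],
                 bd_mask: List[bool]) -> List[Coord]:
    """Reset boundary nodes to their original positions; interior nodes keep the new ones."""
    return [orig if on_bd else new
            for new, orig, on_bd in zip(new_coords, orig_coords, bd_mask)]


orig = [(0.0, 0.0), (0.5, 0.5), (1.0, 0.3)]
moved = [(0.02, -0.01), (0.45, 0.55), (0.98, 0.35)]
mask = find_boundary(orig)
print(mask)                            # [True, False, True]
print(fix_boundary(moved, orig, mask))
```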

UM2N.model.MRN_GTE module

class MRNGlobalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_LTE module

class MRNLocalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_atten module

class MRNAtten(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) with self-attention, implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_fix module

class MRN_fix(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

move(data, num_step=1)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.MRN_phi module

class MRN_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Recurrent Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations for the field phi.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

class RecurrentGATConv(phi_size=1, hidden_size=512, heads=6, concat=False)[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

Initialize the layer.

forward(phi, hidden_state, edge_index)[source]

Runs the forward pass of the module.

UM2N.model.MRT module

class MRTransformer(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')[source]

Bases: Module

Mesh Refinement Network (MRN) that uses a transformer as the feature extractor, together with recurrent graph-based deformations.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, poly_mesh=False)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

get_attention_scores(data)[source]
move(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, num_step=1, poly_mesh=False)[source]
Move the mesh according to the learned deformation, applied for the given number of steps.

Args:

data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.

Returns:

coord (Tensor): Deformed coordinates.

transformer_monitor(data, input_q, input_kv, boundary)[source]

UM2N.model.MRT_PE module

UM2N.model.MRT_phi module

class MRT_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]

Bases: Module

Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.

Attributes:

num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.

Initialize MRN.

Args:

gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.

forward(data)[source]

Forward pass for MRN.

Args:

data (Data): Input data object containing mesh and feature info.

Returns:

coord (Tensor): Deformed coordinates.

UM2N.model.deformer module

class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]

Bases: MessagePassing

Implements a Recurrent Graph Attention Network (GAT) Convolution layer.

Attributes:

to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.

Initialize the layer.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
forward(coord, hidden_state, edge_index, coord_ori, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

UM2N.model.extractor module

class GlobalFeatExtractor(in_c, out_c, drop_p=0.2, use_drop=True)[source]

Bases: Module

Custom PyTorch layer for global feature extraction.

The class employs multiple convolutional layers and dropout layers.

Attributes:

conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
dropout (torch.nn.Dropout): Dropout layer.
final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.

Initialize the layer.

forward(data)[source]

Forward pass.

Args:

data (Tensor): Input data.

Returns:

Tensor: Extracted global features.
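final_pool is a torch.nn.AdaptiveAvgPool2d, which produces a fixed output size whatever the input resolution, so the size of the extracted global features presumably does not depend on the input resolution. The 1-D plain-Python sketch below reproduces the bin boundaries PyTorch uses for adaptive average pooling (for input length L and n output bins, bin i spans indices \(\lfloor iL/n \rfloor\) up to \(\lceil (i+1)L/n \rceil\)); the 2-D layer applies the same rule along height and width. It is an illustration, not UM2N code.

```python
import math
from typing import List


def adaptive_avg_pool1d(values: List[float], output_size: int) -> List[float]:
    """Average values into output_size bins using PyTorch's adaptive pooling bin rule."""
    if output_size < 1:
        raise ValueError("output_size must be at least 1")
    n = len(values)
    out = []
    for i in range(output_size):
        start = math.floor(i * n / output_size)
        end = math.ceil((i + 1) * n / output_size)  # bins may overlap when n % output_size != 0
        window = values[start:end]
        out.append(sum(window) / len(window))
    return out


print(adaptive_avg_pool1d([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [2.0, 4.0]
print(adaptive_avg_pool1d([1.0, 2.0, 3.0, 4.0, 5.0], 1))  # [3.0]
```

With output_size=1 this reduces to a global average, which is the usual choice for a global feature vector.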

class LocalFeatExtractor(num_feat=10, out=16)[source]

Bases: MessagePassing

Custom PyTorch Geometric layer that performs feature extraction on the local graph structure.

The class extends the torch_geometric.nn.MessagePassing class and employs additive aggregation as the message-passing scheme.

Attributes:

lin_1 (torch.nn.Linear): First linear layer.
lin_2 (torch.nn.Linear): Second linear layer.
lin_3 (torch.nn.Linear): Third linear layer.
activate (torch.nn.SELU): Activation function.

Initialize the layer.

Args:

num_feat (int): Number of input features per node.
out (int): Number of output features per node.

forward(input, edge_index)[source]

Forward pass.

Args:

input (Tensor): Node features.
edge_index (Tensor): Edge indices.

Returns:

Tensor: Updated node features.

message(x_i, x_j)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
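LocalFeatExtractor uses additive (sum) aggregation. In PyTorch Geometric, propagate() calls message() once per edge and then aggregates the messages arriving at each target node. The plain-Python sketch below mirrors that flow under PyG's default source_to_target convention, in which edge_index[0] holds the source nodes \(j\) and edge_index[1] the target nodes \(i\). The message function relative_message is a hypothetical stand-in and does not reproduce what LocalFeatExtractor.message actually computes.

```python
from typing import Callable, List, Sequence

Features = List[float]


def propagate_sum(x: Sequence[Features], edge_index: Sequence[Sequence[int]],
                  message: Callable[[Features, Features], Features]) -> List[Features]:
    """Sum the messages arriving at each target node (PyG's source_to_target flow).

    Assumes message() returns vectors with the same length as the node features.
    """
    src, dst = edge_index
    out = [[0.0] * len(x[0]) for _ in x]
    for j, i in zip(src, dst):
        msg = message(x[i], x[j])  # message(x_i, x_j): target first, then source
        out[i] = [a + b for a, b in zip(out[i], msg)]
    return out


def relative_message(x_i: Features, x_j: Features) -> Features:
    """Hypothetical message: the feature difference x_j - x_i."""
    return [b - a for a, b in zip(x_i, x_j)]


x = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
edge_index = [[1, 2, 0], [0, 0, 1]]  # edges 1->0, 2->0 and 0->1
print(propagate_sum(x, edge_index, relative_message))  # [[1.0, 2.0], [-1.0, 0.0], [0.0, 0.0]]
```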

UM2N.model.gatdeformer module

class DeformGAT(in_channels: int, out_channels: int, heads: int = 1, concat: bool = False, negative_slope: float = 0.2, dropout: float = 0, bias: bool = False, **kwargs)[source]

Bases: MessagePassing

Initialize the layer.

find_boundary(in_data)[source]
fix_boundary(in_data)[source]
forward(coords: Tensor | Tuple[Tensor, Tensor | None], features: Tensor | Tuple[Tensor, Tensor | None], edge_index: Tensor | SparseTensor, bd_mask, poly_mesh)[source]

Runs the forward pass of the module.

message(x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) Tensor[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

reset_parameters()[source]

Resets all learnable parameters of the module.
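DeformGAT.message receives per-node attention scores (alpha_i, alpha_j) together with index, ptr and size_i, the inputs PyG's GAT-style layers use to normalise attention per target node. Assuming DeformGAT follows the standard GAT recipe, the two scores are summed per edge, passed through LeakyReLU with negative_slope (0.2 by default in the signature) and normalised with a softmax over all edges that share a target node. PyG performs that last step with torch_geometric.utils.softmax; the sketch below spells it out in plain Python with made-up scores.

```python
import math
from collections import defaultdict
from typing import List, Sequence


def leaky_relu(v: float, negative_slope: float = 0.2) -> float:
    return v if v >= 0 else negative_slope * v


def grouped_softmax(scores: Sequence[float], index: Sequence[int]) -> List[float]:
    """Softmax over the edges that share the same target node index[e]."""
    group_max = defaultdict(lambda: -math.inf)
    for s, i in zip(scores, index):
        group_max[i] = max(group_max[i], s)  # subtract the group max for numerical stability
    exps = [math.exp(s - group_max[i]) for s, i in zip(scores, index)]
    group_sum = defaultdict(float)
    for e, i in zip(exps, index):
        group_sum[i] += e
    return [e / group_sum[i] for e, i in zip(exps, index)]


# Made-up scores: three edges point at node 0, one edge points at node 1.
alpha_i = [0.5, 0.5, 0.5, -1.0]
alpha_j = [1.0, -2.0, 0.0, 0.2]
index = [0, 0, 0, 1]
scores = [leaky_relu(a + b) for a, b in zip(alpha_i, alpha_j)]
weights = grouped_softmax(scores, index)
print([round(w, 3) for w in weights])
```

The weights for each target node sum to one, so a node with a single incoming edge always gives that edge full weight.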

UM2N.model.train_util module

class TangleCounter(num_feat=10, out=16)[source]

Bases: MessagePassing

A PyTorch Geometric message-passing class for counting tangles in the mesh. This class is deprecated; do not use it unless you know what you are doing.

Initialize the layer.

forward(x, x_new, edge_index)[source]

Runs the forward pass of the module.

message(x_i, x_j, x_new_i, x_new_j)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

count_dataset_tangle(dataset, model, device, method='inversion')[source]

Computes the average number of tangles in a dataset.

Args:

dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.

Returns:

float: The average number of tangles in the dataset.

count_dataset_tangle_repeat_sampling(dataset, model, device, num_samples=1)[source]

Computes the average number of tangles in a dataset.

Args:

dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.

Returns:

float: The average number of tangles in the dataset.

evaluate(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]

Evaluates a model using the given data loader and loss function.

Args:

loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.

Returns:

float: The average evaluation loss across all batches.

evaluate_repeat(dataset, model, device, loss_func, scaler=100, num_repeat=1)[source]

Evaluates model performance over repeated sampling. This function evaluates:

  1. the average loss

  2. the average number of tangles

Args:

dataset (MeshDataset): The target dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).

Returns:

float: The average evaluation loss across all batches.

evaluate_repeat_sampling(dataset, model, device, loss_func, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, scaler=100, batch_size=5, num_samples=1)[source]

Evaluates a model on the given dataset using the given loss function.

Args:

dataset (Dataset): The dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).

Returns:

float: The average evaluation loss across all batches.

evaluate_unsupervised(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_eq_residual_loss=1, weight_chamfer_loss=1, scaler=100)[source]

Evaluates a model using the given data loader and loss function.

Args:

loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.

Returns:

float: The average evaluation loss across all batches.

get_area_loss(out_coord, tar_coord, face, batch_size, scaler=100)[source]
get_face_area(coord, face)[source]
Calculates the signed area of each triangular face using the formula:

\[ A = \tfrac{1}{2}\left[x_1(y_2 - y_3) + x_2(y_3 - y_1) + x_3(y_1 - y_2)\right] \]

Args:

coord (torch.Tensor): The coordinates.
face (torch.Tensor): The face tensor.
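The formula can be checked in plain Python. The sketch below is named signed_face_areas to keep it distinct from the UM2N function and assumes face is a list of vertex-index triples rather than a tensor. Swapping two vertices of a triangle flips the sign of its area, which is what makes the signed area useful for detecting inverted elements.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[float, float]


def signed_face_areas(coord: Sequence[Coord], face: Sequence[Tuple[int, int, int]]) -> List[float]:
    """Signed area of each triangle: positive for counter-clockwise vertex order."""
    areas = []
    for a, b, c in face:
        (x1, y1), (x2, y2), (x3, y3) = coord[a], coord[b], coord[c]
        areas.append(0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)))
    return areas


coord = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
print(signed_face_areas(coord, [(0, 1, 2), (0, 2, 1)]))  # [0.5, -0.5]
```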

get_inversion_diff_loss(out_coord, tar_coord, face, batch_size, scaler=100)[source]

Calculates the inversion difference loss for a batch of meshes, that is, the difference between the output area and the input area for the inverted elements.

Args:

out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.

get_inversion_loss(out_coord, in_coord, face, batch_size, scheme='relu', scaler=100)[source]

Calculates the inversion loss for a batch of meshes.

Args:

out_coord (torch.Tensor): The output coordinates.
in_coord (torch.Tensor): The input coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.
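The scheme='relu' option suggests a hinge-style penalty. A minimal sketch, assuming that an element counts as inverted when its signed area has the opposite sign in the output mesh compared with the input mesh, that each inverted element contributes the magnitude of its flipped output area, and that the sum is multiplied by scaler and divided by batch_size. The function name inversion_loss_relu and this exact normalisation are assumptions, not UM2N's implementation.

```python
from typing import Sequence


def inversion_loss_relu(in_areas: Sequence[float], out_areas: Sequence[float],
                        batch_size: int, scaler: float = 100.0) -> float:
    """Hypothetical ReLU-style penalty on elements whose signed area flips sign.

    in_areas / out_areas are pre-computed signed areas of the same elements.
    """
    if len(in_areas) != len(out_areas):
        raise ValueError("in_areas and out_areas must have the same length")
    total = 0.0
    for a_in, a_out in zip(in_areas, out_areas):
        sign = 1.0 if a_in >= 0 else -1.0  # input orientation is the reference
        # ReLU(-sign * a_out) is zero unless the output orientation disagrees with the input.
        total += max(0.0, -sign * a_out)
    return scaler * total / batch_size


print(inversion_loss_relu([0.5, 0.5], [0.4, -0.1], batch_size=1))
```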

get_inversion_node_loss(out_coord, tar_coord, face, batch_size, scaler=1000)[source]

Calculates the loss between the output nodes and the input nodes for the inverted elements. This penalises the nodes that are involved in tangled elements.

Args:

out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
alpha (float): The loss weight.

get_jacob_det(model, in_data)[source]
get_sample_tangle(out_coords, in_coords, face)[source]

Return the number of tangled elements in a single sample.
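One common way to count tangled elements, assumed here, is to count the triangles whose signed area changes sign between the input and output meshes. A self-contained sketch; count_sample_tangle is a hypothetical name and face is taken to be a list of vertex-index triples.

```python
from typing import Sequence, Tuple

Coord = Tuple[float, float]


def signed_area(p1: Coord, p2: Coord, p3: Coord) -> float:
    (x1, y1), (x2, y2), (x3, y3) = p1, p2, p3
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))


def count_sample_tangle(out_coords: Sequence[Coord], in_coords: Sequence[Coord],
                        face: Sequence[Tuple[int, int, int]]) -> int:
    """Count triangles whose orientation (sign of signed area) differs between input and output."""
    count = 0
    for a, b, c in face:
        a_in = signed_area(in_coords[a], in_coords[b], in_coords[c])
        a_out = signed_area(out_coords[a], out_coords[b], out_coords[c])
        if a_in * a_out < 0:  # opposite signs: the element has been inverted
            count += 1
    return count


in_coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
face = [(0, 1, 2), (1, 3, 2)]
# Node 2 is pushed below the x-axis, which inverts the first triangle only.
out_coords = [(0.0, 0.0), (1.0, 0.0), (0.0, -0.5), (1.0, 1.0)]
print(count_sample_tangle(out_coords, in_coords, face))  # 1
```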

load_model(model, weight_path, strict=False)[source]

Loads pre-trained weights into a PyTorch model from a given file path.

Args:

model (torch.nn.Module): The PyTorch model.
weight_path (str): File path to the pre-trained model weights.

Returns:

torch.nn.Module: The model with loaded weights.

train(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]
Trains a PyTorch model using the given data loader, optimizer, and loss function.

Args:

loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.

Returns:

float: The average training loss across all batches.
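The weight_deform_loss, weight_area_loss and weight_chamfer_loss arguments suggest that train combines several loss terms as a weighted sum and then averages over batches. The sketch below assumes exactly that; the helper names combine_losses and average_epoch_loss are hypothetical, and the real function may include further terms (for example, the optional inversion losses).

```python
from typing import Sequence


def combine_losses(deform_loss: float, area_loss: float = 0.0, chamfer_loss: float = 0.0,
                   weight_deform_loss: float = 1.0, weight_area_loss: float = 1.0,
                   weight_chamfer_loss: float = 0.0) -> float:
    """Weighted sum of loss terms; defaults mirror the keyword defaults in train's signature."""
    return (weight_deform_loss * deform_loss
            + weight_area_loss * area_loss
            + weight_chamfer_loss * chamfer_loss)


def average_epoch_loss(batch_losses: Sequence[float]) -> float:
    """Average loss across all batches, which is what train is documented to return."""
    if not batch_losses:
        raise ValueError("no batches")
    return sum(batch_losses) / len(batch_losses)


batch_losses = [combine_losses(0.2, area_loss=0.05), combine_losses(0.1, area_loss=0.15)]
print(average_epoch_loss(batch_losses))
```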

train_unsupervised(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_chamfer_loss=1, weight_eq_residual_loss=1, scaler=100)[source]
Trains a PyTorch model using the given data loader, optimizer, and loss function.

Args:

loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.

Returns:

float: The average training loss across all batches.

UM2N.model.transformer_model module

class MLP_model(input_channels, output_channels, list_hiddens=[128, 128], hidden_act='LeakyReLU', output_act='LeakyReLU', input_norm=None, dropout_prob=0.0)[source]

Bases: Module

Note that list_hiddens should be a list giving the number of hidden channels in each MLP layer, e.g. [64, 128, 64].

Args:

input_channels: Number of input channels.
output_channels: Number of output channels.
list_hiddens (list, optional): Hidden channels per layer. Defaults to [128, 128].
hidden_act (str, optional): Activation for the hidden layers. Defaults to “LeakyReLU”.
output_act (str, optional): Activation for the output layer. Defaults to “LeakyReLU”.
input_norm (str, optional): Input normalisation, one of [“BatchNorm1d”, “LayerNorm”]. Defaults to None.
dropout_prob (float, optional): Dropout probability. Defaults to 0.0.

forward(x)[source]

Run the forward pass.

Note

Call the module instance (for example, module(x)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.
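Reading list_hiddens as the widths of the hidden layers, the linear layers of such an MLP chain together as below: each hidden width is the output of one layer and the input of the next. This is inferred from the parameter names, not taken from the UM2N source, and mlp_layer_shapes is a hypothetical helper. It uses a tuple default instead of the signature's list default [128, 128], which sidesteps Python's shared-mutable-default pitfall.

```python
from typing import List, Sequence, Tuple


def mlp_layer_shapes(input_channels: int, output_channels: int,
                     list_hiddens: Sequence[int] = (128, 128)) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer of a plain MLP."""
    widths = [input_channels, *list_hiddens, output_channels]
    return list(zip(widths[:-1], widths[1:]))


print(mlp_layer_shapes(4, 16, [64, 128, 64]))  # [(4, 64), (64, 128), (128, 64), (64, 16)]
```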

class TransformerBlock(embed_dim, num_heads, dense_ratio=4, list_dropout=[0.1, 0.1, 0.1], activation='GELU')[source]

Bases: Module

Initialize the block.

forward(x, k, v, x_cls=None, key_padding_mask=None, attn_mask=None, return_attn=False)[source]

Run the forward pass.

Note

Call the module instance (for example, block(x, k, v)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.
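TransformerBlock.forward takes separate query (x), key (k) and value (v) inputs plus an optional key_padding_mask. Assuming the PyTorch nn.MultiheadAttention convention, where True marks a key to ignore, the single-head plain-Python sketch below shows what such a mask does: masked scores are set to \(-\infty\) before the softmax, so those keys receive zero weight. The attend function is illustrative and is not the UM2N implementation.

```python
import math
from typing import List, Optional, Sequence


def attend(q: Sequence[float], keys: Sequence[Sequence[float]], values: Sequence[Sequence[float]],
           key_padding_mask: Optional[Sequence[bool]] = None) -> List[float]:
    """Scaled dot-product attention for one query; True in key_padding_mask means 'ignore this key'."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    if key_padding_mask is not None:
        scores = [-math.inf if masked else s for s, masked in zip(scores, key_padding_mask)]
    m = max(scores)  # assumes at least one key is left unmasked
    weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0 for masked keys
    total = sum(weights)
    weights = [w / total for w in weights]
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [10.0, 10.0]]
print(attend(q, keys, values))                                        # third key dominates
print(attend(q, keys, values, key_padding_mask=[False, False, True]))  # third key ignored
```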

class TransformerModel(*, input_dim, embed_dim, output_dim, num_heads=4, num_layers=3)[source]

Bases: Module

Initialize the model.

forward(x, k, v, key_padding_mask=None, attention_mask=None)[source]

Run the forward pass.

Note

Call the module instance (for example, model(x, k, v)) rather than calling forward() directly: the instance call runs any registered hooks, whereas calling forward() directly silently ignores them.

get_attention_scores(x, key_padding_mask=None, attn_mask=None)[source]

Module contents