UM2N.model package
Submodules
UM2N.model.M2N module
- class NetGATDeform(in_dim)
Bases: Module
- forward(data, edge_idx, bd_mask, poly_mesh)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class M2N(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)
Bases: Module
- forward(data, poly_mesh=False)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_T module
- class NetGATDeform(in_dim)
Bases: Module
- forward(data, edge_idx, bd_mask, poly_mesh)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class M2N_T(gfe_in_c=3, lfe_in_c=3, deform_in_c=3)
Bases: Module
- forward(data, poly_mesh=False)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_atten module
- class NetGATDeform(in_dim)
Bases: Module
- forward(data, edge_idx)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class M2NAtten(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)
Bases: Module
- forward(data)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_dynamic_drop module
- class NetGATDeform(in_dim)
Bases: Module
- forward(data, edge_idx)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class M2N_dynamic_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)
Bases: Module
- forward(data)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_dynamic_no_drop module
- class NetGATDeform(in_dim)
Bases: Module
- forward(data, edge_idx)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class M2N_dynamic_no_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)
Bases: Module
- forward(data)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2T module
- class M2T(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', local_feature_dim_in=4, num_loop=3, device='cuda')
Bases: Module
- Mesh Refinement Network (MRN) implementing a transformer as the feature extractor and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
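The paired transformer_training_mask_ratio_lower_bound and transformer_training_mask_ratio_upper_bound parameters suggest that, when training masks are enabled, a random fraction of tokens is masked, with that fraction drawn between the two bounds. The plain-Python sketch below shows one way such a mask could be built; the uniform draw and the helper name sample_training_mask are assumptions for illustration, not M2T's actual implementation.

```python
import random
from typing import List, Optional


def sample_training_mask(num_tokens: int, lower: float = 0.5, upper: float = 0.9,
                         rng: Optional[random.Random] = None) -> List[bool]:
    """Hypothetical helper: return a mask in which True marks a masked token.

    Assumption: the mask ratio is drawn uniformly from [lower, upper].
    """
    if not 0.0 <= lower <= upper <= 1.0:
        raise ValueError("expected 0 <= lower <= upper <= 1")
    rng = rng if rng is not None else random.Random()
    ratio = rng.uniform(lower, upper)
    num_masked = round(ratio * num_tokens)  # at most num_tokens because ratio <= 1
    masked = set(rng.sample(range(num_tokens), num_masked))
    return [i in masked for i in range(num_tokens)]


# A 100-token mask with between 50 and 90 tokens masked, reproducible via a seeded RNG.
mask = sample_training_mask(100, rng=random.Random(0))
print(sum(mask))
```

Drawing a fresh ratio per call, rather than fixing it, exposes the model to a range of masking levels; whether M2T samples per batch or per sample is not stated here.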
UM2N.model.M2T_deformer module
- class M2TDeformer(feature_in_dim, local_feature_dim_in, coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')
Bases: MessagePassing
Implements the M2T deformer, a graph-attention (GATv2Conv) based message-passing layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
UM2N.model.MRN module
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
- class MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
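The num_loop attribute indicates that the deformer block is applied repeatedly, each pass starting from the coordinates produced by the previous one. A minimal plain-Python sketch of that recurrence follows, assuming each pass predicts an additive coordinate update and carries a hidden state forward; toy_step is a hypothetical stand-in for RecurrentGATConv, not its real computation.

```python
from typing import Callable, List, Tuple

Coord = Tuple[float, float]
# Hypothetical step signature: (coords, hidden) -> (coordinate updates, new hidden state)
Step = Callable[[List[Coord], List[float]], Tuple[List[Coord], List[float]]]


def recurrent_deform(coords: List[Coord], hidden: List[float], step: Step,
                     num_loop: int = 3) -> List[Coord]:
    """Apply `step` num_loop times, feeding each output back in as the next input."""
    for _ in range(num_loop):
        deltas, hidden = step(coords, hidden)
        coords = [(x + dx, y + dy) for (x, y), (dx, dy) in zip(coords, deltas)]
    return coords


def toy_step(coords: List[Coord], hidden: List[float]) -> Tuple[List[Coord], List[float]]:
    """Toy stand-in: move every node 10% of the way towards (0.5, 0.5)."""
    deltas = [(0.1 * (0.5 - x), 0.1 * (0.5 - y)) for x, y in coords]
    return deltas, hidden


print(recurrent_deform([(0.0, 0.0), (1.0, 1.0)], [0.0, 0.0], toy_step, num_loop=3))
```

Because the same step is reused on every loop, increasing num_loop adds refinement passes without adding parameters.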
UM2N.model.MRN_GTE module
- class MRNGlobalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.MRN_LTE module
- class MRNLocalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.MRN_atten module
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
- class MRNAtten(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) with self-attention, implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.MRN_fix module
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
- class MRN_fix(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.MRN_phi module
- class RecurrentGATConv(phi_size=1, hidden_size=512, heads=6, concat=False)
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
- class MRN_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Recurrent Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations for the field phi.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.MRT module
- class MRTransformer(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')
Bases: Module
- Mesh Refinement Network (MRN) implementing a transformer as the feature extractor and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
- move(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, num_step=1, poly_mesh=False)
- Moves the mesh according to the learned deformation over the given number of steps.
- Args:
data (Data): Input data object containing mesh and feature information.
num_step (int): Number of deformation steps.
- Returns:
coord (Tensor): Deformed coordinates.
UM2N.model.MRT_PE module
UM2N.model.MRT_phi module
- class RecurrentGATConv(in_size=1, hidden_size=512, heads=6, concat=False)
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
- class MRT_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
UM2N.model.deformer module
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
UM2N.model.extractor module
- class LocalFeatExtractor(num_feat=10, out=16)
Bases: MessagePassing
Custom PyTorch Geometric layer that performs feature extraction on the local graph structure.
The class extends torch_geometric.nn.MessagePassing and uses additive aggregation as its message-passing scheme.
- Attributes:
lin_1 (torch.nn.Linear): First linear layer.
lin_2 (torch.nn.Linear): Second linear layer.
lin_3 (torch.nn.Linear): Third linear layer.
activate (torch.nn.SELU): Activation function.
- forward(input, edge_index)
Forward pass.
- Args:
input (Tensor): Node features.
edge_index (Tensor): Edge indices.
- Returns:
Tensor: Updated node features.
- message(x_i, x_j)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
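With additive aggregation, each node i sums the messages computed on the edges (j, i) of edge_index; under PyTorch Geometric's default source_to_target flow, row 0 holds the source j and row 1 the target i. The plain-Python sketch below illustrates that scheme. The message itself (x_j - x_i) is a hypothetical choice for illustration, since the actual computation in LocalFeatExtractor.message (with lin_1 to lin_3 and SELU) is not documented here.

```python
from typing import List, Sequence


def message(x_i: Sequence[float], x_j: Sequence[float]) -> List[float]:
    """Hypothetical message: the feature of neighbour j relative to node i."""
    return [b - a for a, b in zip(x_i, x_j)]


def propagate_sum(x: List[List[float]], edge_index: List[List[int]]) -> List[List[float]]:
    """Sum-aggregate messages from sources j (row 0) onto targets i (row 1)."""
    dim = len(x[0])
    out = [[0.0] * dim for _ in x]
    for j, i in zip(edge_index[0], edge_index[1]):
        msg = message(x[i], x[j])
        for d in range(dim):
            out[i][d] += msg[d]  # additive ("add") aggregation
    return out


# Edges 1->0, 2->0 and 0->1 on a three-node graph with scalar features.
print(propagate_sum([[0.0], [1.0], [3.0]], [[1, 2, 0], [0, 0, 1]]))  # [[4.0], [-1.0], [0.0]]
```

Nodes with no incoming edges (node 2 here) keep a zero aggregate, which is the usual behaviour of sum aggregation.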
- class GlobalFeatExtractor(in_c, out_c, drop_p=0.2, use_drop=True)
Bases: Module
Custom PyTorch layer for global feature extraction.
The class employs multiple convolutional layers and dropout layers.
- Attributes:
conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
dropout (torch.nn.Dropout): Dropout layer.
final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.
- class TransformerEncoder(num_transformer_in=3, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')
Bases: Module
- Mesh Refinement Network (MRN) implementing a transformer as the feature extractor and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
- forward(data)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.gatdeformer module
- class DeformGAT(in_channels: int, out_channels: int, heads: int = 1, concat: bool = False, negative_slope: float = 0.2, dropout: float = 0, bias: bool = False, **kwargs)
Bases: MessagePassing
- forward(coords: Tensor | Tuple[Tensor, Tensor | None], features: Tensor | Tuple[Tensor, Tensor | None], edge_index: Tensor | SparseTensor, bd_mask, poly_mesh)
Runs the forward pass of the module.
- message(x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) → Tensor
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
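The message signature (alpha_j, alpha_i, index, ptr, size_i) mirrors the one used by PyTorch Geometric's GATConv in releases that pass the two attention halves separately: per-edge scores alpha_j + alpha_i go through LeakyReLU (slope negative_slope) and are then normalised with a softmax over all edges that share the same target index. Assuming DeformGAT follows that pattern, the normalisation looks like this in plain Python (single head, no dropout); how DeformGAT then uses the weights for coordinates is not stated here.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def leaky_relu(v: float, negative_slope: float = 0.2) -> float:
    return v if v >= 0 else negative_slope * v


def edge_softmax(scores: Sequence[float], index: Sequence[int]) -> List[float]:
    """Softmax of per-edge scores, normalised separately for each target node."""
    max_per_target: Dict[int, float] = {}
    for s, i in zip(scores, index):  # per-target maximum for numerical stability
        max_per_target[i] = max(s, max_per_target.get(i, -math.inf))
    exps = [math.exp(s - max_per_target[i]) for s, i in zip(scores, index)]
    totals: Dict[int, float] = defaultdict(float)
    for e, i in zip(exps, index):
        totals[i] += e
    return [e / totals[i] for e, i in zip(exps, index)]


def attention_weights(alpha_j: Sequence[float], alpha_i: Sequence[float],
                      index: Sequence[int], negative_slope: float = 0.2) -> List[float]:
    """One weight per edge: softmax over target i of LeakyReLU(alpha_j + alpha_i)."""
    raw = [leaky_relu(a + b, negative_slope) for a, b in zip(alpha_j, alpha_i)]
    return edge_softmax(raw, index)


# Two edges into node 0 with equal scores, one edge into node 1.
print(attention_weights([1.0, 1.0, 2.0], [0.0, 0.0, 0.0], [0, 0, 1]))  # [0.5, 0.5, 1.0]
```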
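UM2N.model.train_util module
- get_face_area(coord, face)
- Calculates the area of a triangular face using the formula:
\(\text{area} = \tfrac{1}{2}\left[x_1(y_2 - y_3) + x_2(y_3 - y_1) + x_3(y_1 - y_2)\right]\)
where \((x_k, y_k)\) are the coordinates of the face's three vertices. As written (without an absolute value), the expression is a signed area: positive for counter-clockwise vertex ordering and negative for clockwise ordering.
- Args:
coord (torch.Tensor): The coordinates.
face (torch.Tensor): The face tensor.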
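The formula above is the shoelace formula for a triangle. A plain-Python sketch that evaluates it per face follows; representing faces as (i, j, k) vertex-index tuples is an assumption for readability, whereas the real function works on torch tensors.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def signed_face_area(coord: Sequence[Point], face: Face) -> float:
    """0.5 * (x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2)).

    Positive for counter-clockwise vertex order, negative for clockwise.
    """
    (x1, y1), (x2, y2), (x3, y3) = (coord[v] for v in face)
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))


def face_areas(coord: Sequence[Point], faces: Sequence[Face]) -> List[float]:
    """Signed area of every triangle in the mesh."""
    return [signed_face_area(coord, f) for f in faces]


# Unit right triangle, listed counter-clockwise and then clockwise.
coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
print(face_areas(coords, [(0, 1, 2), (0, 2, 1)]))  # [0.5, -0.5]
```

Keeping the sign is what makes the area useful for the inversion losses: an element whose area changes sign after deformation has been flipped.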
- get_inversion_loss(out_coord, in_coord, face, batch_size, scheme='relu', scaler=100)
Calculates the inversion loss for a batch of meshes.
- Args:
out_coord (torch.Tensor): The output coordinates.
in_coord (torch.Tensor): The input coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scheme (str): The scheme used to compute the loss. Defaults to 'relu'.
scaler (float): Scaling factor applied to the loss. Defaults to 100.
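An element is inverted when its signed area changes sign between the input and output meshes. One plausible reading of scheme='relu' and scaler is sketched below: each output area is penalised by max(0, ·) of how far it has moved to the opposite sign of its input area, then averaged and scaled. This loss form is an assumption, not the function's documented definition; the signed area follows the get_face_area formula.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def signed_area(coord: Sequence[Point], face: Face) -> float:
    (x1, y1), (x2, y2), (x3, y3) = (coord[v] for v in face)
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))


def inversion_loss_relu(out_coord: Sequence[Point], in_coord: Sequence[Point],
                        faces: List[Face], scaler: float = 100.0) -> float:
    """Hypothetical ReLU-style inversion penalty, averaged over faces."""
    total = 0.0
    for f in faces:
        sign_in = 1.0 if signed_area(in_coord, f) >= 0 else -1.0
        # Positive only when the output area has the opposite sign to the input area.
        total += max(0.0, -sign_in * signed_area(out_coord, f))
    return scaler * total / len(faces)


in_coord = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
flipped = [(0.0, 0.0), (1.0, 0.0), (0.0, -1.0)]  # third vertex pushed through the edge
print(inversion_loss_relu(flipped, in_coord, [(0, 1, 2)]))  # 50.0
```

Non-inverted elements contribute exactly zero, so the penalty only produces gradients for elements that have actually flipped.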
- get_inversion_diff_loss(out_coord, tar_coord, face, batch_size, scaler=100)
Calculates the inversion difference loss for a batch of meshes, that is, the difference between the output area and the target area over the inverted elements.
- Args:
out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scaler (float): Scaling factor applied to the loss. Defaults to 100.
- get_inversion_node_loss(out_coord, tar_coord, face, batch_size, scaler=1000)
Calculates the loss between the output and target node positions for the inverted elements. This penalises the nodes that are involved in tangled elements.
- Args:
out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scaler (float): Scaling factor applied to the loss. Defaults to 1000.
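The node-level variant can be sketched as: find the faces whose signed area has opposite signs in the output and target meshes, collect their vertex indices, and penalise the squared displacement between output and target positions of those nodes only. The mean-squared form and the use of scaler as a plain multiplier are assumptions made for this illustration.

```python
from typing import List, Sequence, Set, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def signed_area(coord: Sequence[Point], face: Face) -> float:
    (x1, y1), (x2, y2), (x3, y3) = (coord[v] for v in face)
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))


def inverted_nodes(out_coord: Sequence[Point], tar_coord: Sequence[Point],
                   faces: List[Face]) -> Set[int]:
    """Vertex indices of every face whose orientation differs from the target."""
    nodes: Set[int] = set()
    for f in faces:
        if signed_area(out_coord, f) * signed_area(tar_coord, f) < 0:
            nodes.update(f)
    return nodes


def inversion_node_loss(out_coord: Sequence[Point], tar_coord: Sequence[Point],
                        faces: List[Face], scaler: float = 1000.0) -> float:
    """Hypothetical scaled mean squared displacement over nodes of inverted faces."""
    nodes = inverted_nodes(out_coord, tar_coord, faces)
    if not nodes:
        return 0.0
    sq = sum((out_coord[n][0] - tar_coord[n][0]) ** 2 +
             (out_coord[n][1] - tar_coord[n][1]) ** 2 for n in nodes)
    return scaler * sq / len(nodes)


tar = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
out = [(0.0, 0.0), (1.0, 0.0), (0.0, -1.0)]
print(inversion_node_loss(out, tar, [(0, 1, 2)]))
```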
- class TangleCounter(num_feat=10, out=16)
Bases: MessagePassing
A PyTorch Geometric MessagePassing class for counting tangles in the mesh. This class is deprecated; do not use it unless you know what you are doing.
- message(x_i, x_j, x_new_i, x_new_j)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- train(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)
- Trains a PyTorch model using the given data loader, optimizer, and loss function.
- Args:
loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.
- Returns:
float: The average training loss across all batches.
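The weight_deform_loss, weight_area_loss and weight_chamfer_loss arguments suggest that the training objective is a weighted sum of individual loss terms, where a weight of 0.0 (the default for weight_chamfer_loss) switches a term off, and the Returns section states that the result is averaged over batches. A minimal sketch of that bookkeeping; the term names and both helpers are hypothetical, and the real function computes the terms from model outputs.

```python
from typing import Dict, Iterable


def combine_losses(terms: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum; a term with no weight, or weight 0.0, contributes nothing."""
    return sum(weights.get(name, 0.0) * value for name, value in terms.items())


def average_epoch_loss(batch_terms: Iterable[Dict[str, float]],
                       weights: Dict[str, float]) -> float:
    """Average of the combined loss across batches (0.0 for an empty loader)."""
    losses = [combine_losses(t, weights) for t in batch_terms]
    return sum(losses) / len(losses) if losses else 0.0


weights = {"deform": 1.0, "area": 1.0, "chamfer": 0.0}
batches = [{"deform": 0.5, "area": 0.25, "chamfer": 10.0},
           {"deform": 1.5, "area": 0.75, "chamfer": 10.0}]
print(average_epoch_loss(batches, weights))  # 1.5
```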
- evaluate(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)
Evaluates a model using the given data loader and loss function.
- Args:
loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.
- Returns:
float: The average evaluation loss across all batches.
- interpolate(u, ori_mesh_x, ori_mesh_y, moved_x, moved_y)
u: [bs, node_num, 1]
ori_mesh_x: [bs, node_num, 1]
ori_mesh_y: [bs, node_num, 1]
moved_x: [bs, node_num, 1]
moved_y: [bs, node_num, 1]
Note: node_num equals sample_num.
- compute_phi_hessian(mesh_query_x, mesh_query_y, phix, phiy, out_monitor, bs, data, loss_func, finite_difference_grad=False)
- sample_nodes_by_monitor(meshes, meshes_target, monitors, num_samples_per_mesh=100, random_seed=666)
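The parameters of sample_nodes_by_monitor (monitors, num_samples_per_mesh, random_seed) suggest drawing a fixed number of nodes per mesh with probability tied to the monitor function, so that regions the monitor flags receive more samples, reproducibly for a given seed. The sketch below handles a single mesh and assumes sampling with replacement, with probability proportional to the monitor value; the helper is hypothetical and the real function's scheme is not documented here.

```python
import random
from typing import List, Sequence


def sample_by_monitor(monitor: Sequence[float], num_samples: int = 100,
                      random_seed: int = 666) -> List[int]:
    """Return node indices drawn with replacement, weighted by monitor value."""
    if any(m < 0 for m in monitor) or sum(monitor) <= 0:
        raise ValueError("monitor values must be non-negative with a positive sum")
    rng = random.Random(random_seed)  # local RNG, so results depend only on the seed
    return rng.choices(range(len(monitor)), weights=monitor, k=num_samples)


# A monitor concentrated on node 1 always selects node 1.
print(sample_by_monitor([0.0, 1.0, 0.0], num_samples=5))  # [1, 1, 1, 1, 1]
```

Using a dedicated random.Random instance, rather than the global generator, keeps the sampling reproducible without affecting other random state.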
- train_unsupervised(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_chamfer_loss=1, weight_eq_residual_loss=1, scaler=100)
- Trains a PyTorch model using the given data loader, optimizer, and loss function.
- Args:
loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.
- Returns:
float: The average training loss across all batches.
- evaluate_unsupervised(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_eq_residual_loss=1, weight_chamfer_loss=1, scaler=100)
Evaluates a model using the given data loader and loss function.
- Args:
loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.
- Returns:
float: The average evaluation loss across all batches.
- get_sample_tangle(out_coords, in_coords, face)
Returns the number of tangled elements in a single sample.
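A tangled (inverted) element can be detected by comparing the sign of its signed area, computed with the get_face_area formula, before and after deformation. A plain-Python sketch of such a count follows; treating a strict sign change as the tangle criterion is an assumption consistent with the inversion losses, and faces are given as (i, j, k) index tuples for readability.

```python
from typing import Sequence, Tuple

Point = Tuple[float, float]
Face = Tuple[int, int, int]


def signed_area(coord: Sequence[Point], face: Face) -> float:
    (x1, y1), (x2, y2), (x3, y3) = (coord[v] for v in face)
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))


def count_tangled(out_coords: Sequence[Point], in_coords: Sequence[Point],
                  faces: Sequence[Face]) -> int:
    """Number of faces whose orientation flips between input and output."""
    return sum(1 for f in faces
               if signed_area(out_coords, f) * signed_area(in_coords, f) < 0)


# Unit square split into two triangles; node 2 is pushed below the bottom edge.
in_coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
out_coords = [(0.0, 0.0), (1.0, 0.0), (0.0, -1.0), (1.0, 1.0)]
faces = [(0, 1, 2), (1, 3, 2)]
print(count_tangled(out_coords, in_coords, faces))  # 1
```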
- count_dataset_tangle(dataset, model, device, method='inversion')
Computes the average number of tangles in a dataset.
- Args:
dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation.
- Returns:
float: The average number of tangles in the dataset.
- evaluate_repeat_sampling(dataset, model, device, loss_func, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, scaler=100, batch_size=5, num_samples=1)
Evaluates a model on the given dataset using the given loss function.
- Args:
dataset (Dataset): The dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
- Returns:
float: The average evaluation loss across all batches.
- count_dataset_tangle_repeat_sampling(dataset, model, device, num_samples=1)
Computes the average number of tangles in a dataset.
- Args:
dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation.
- Returns:
float: The average number of tangles in the dataset.
- evaluate_repeat(dataset, model, device, loss_func, scaler=100, num_repeat=1)
Evaluates model performance when sampling is repeated a given number of times. This function evaluates:
the average loss
the average number of tangles
- Args:
dataset (MeshDataset): The target dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
- Returns:
float: The average evaluation loss across all batches.
UM2N.model.transformer_model module
- class MLP_model(input_channels, output_channels, list_hiddens=[128, 128], hidden_act='LeakyReLU', output_act='LeakyReLU', input_norm=None, dropout_prob=0.0)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
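With list_hiddens=[128, 128], a natural reading of MLP_model is a chain of Linear layers mapping input_channels to 128, 128 to 128, and 128 to output_channels, with hidden_act between layers and output_act after the last. The hypothetical helpers below derive those (in, out) pairs and the resulting Linear parameter count; they sketch the likely layout and ignore input_norm and dropout, so they are not taken from MLP_model's source.

```python
from typing import List, Sequence, Tuple


def mlp_layer_shapes(input_channels: int, output_channels: int,
                     list_hiddens: Sequence[int] = (128, 128)) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each Linear layer, first to last."""
    dims = [input_channels, *list_hiddens, output_channels]
    return list(zip(dims[:-1], dims[1:]))


def linear_param_count(shapes: Sequence[Tuple[int, int]]) -> int:
    """Weights plus biases, assuming every Linear layer has a bias (PyTorch's default)."""
    return sum(i * o + o for i, o in shapes)


shapes = mlp_layer_shapes(4, 16)
print(shapes)                      # [(4, 128), (128, 128), (128, 16)]
print(linear_param_count(shapes))  # 19216
```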
- class TransformerBlock(embed_dim, num_heads, dense_ratio=4, list_dropout=[0.1, 0.1, 0.1], activation='GELU')
Bases: Module
- forward(x, k, v, x_cls=None, key_padding_mask=None, attn_mask=None, return_attn=False)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class TransformerModel(*, input_dim, embed_dim, output_dim, num_heads=4, num_layers=3)
Bases: Module
- forward(x, k, v, key_padding_mask=None, attention_mask=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
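TransformerModel.forward accepts a key_padding_mask. If it is passed through to torch.nn.MultiheadAttention, which is an assumption here, a boolean mask of shape (batch, seq_len) uses True to mark padded key positions that attention should ignore. The hypothetical helper below builds such a mask from sequence lengths in plain Python; it would need converting with torch.tensor(...) before use.

```python
from typing import List, Optional, Sequence


def build_key_padding_mask(lengths: Sequence[int],
                           max_len: Optional[int] = None) -> List[List[bool]]:
    """Return a (batch, max_len) boolean mask; True marks a padded key position."""
    if max_len is None:
        max_len = max(lengths, default=0)
    if any(n > max_len for n in lengths):
        raise ValueError("a sequence is longer than max_len")
    return [[pos >= n for pos in range(max_len)] for n in lengths]


# Two sequences of length 2 and 3, padded to length 3.
print(build_key_padding_mask([2, 3]))
# [[False, False, True], [False, False, False]]
```

Every row keeps at least one False entry as long as each length is positive; a row that is entirely True would leave attention with no valid keys for that sequence.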