UM2N.model package¶
Submodules¶
UM2N.model.M2N module¶
- class M2N(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data, poly_mesh=False)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_T module¶
- class M2N_T(gfe_in_c=3, lfe_in_c=3, deform_in_c=3)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data, poly_mesh=False)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_atten module¶
- class M2NAtten(gfe_in_c=1, lfe_in_c=3, deform_in_c=7, use_drop=False)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_dynamic_drop module¶
- class M2N_dynamic_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2N_dynamic_no_drop module¶
- class M2N_dynamic_no_drop(gfe_in_c=1, lfe_in_c=3, deform_in_c=7)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
UM2N.model.M2T module¶
- class M2T(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', local_feature_dim_in=4, num_loop=3, device='cuda')[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing a transformer as feature extractor and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.M2T_deformer module¶
- class M2TDeformer(feature_in_dim, local_feature_dim_in, coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]¶
Bases: MessagePassing
Implements an M2TDeformer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UM2N.model.MRN module¶
- class MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
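As a quick orientation, the sketch below instantiates MRN with its documented defaults. The import path is inferred from this page's module layout, so treat it as an assumption; the same pattern applies to the other MRN variants documented below.

```python
from UM2N.model.MRN import MRN  # import path assumed from this page

# Constructor arguments and defaults are taken from the signature above.
model = MRN(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)
print(sum(p.numel() for p in model.parameters()), "parameters")
```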
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, concat=False)[source]¶
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UM2N.model.MRN_GTE module¶
- class MRNGlobalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.MRN_LTE module¶
- class MRNLocalTransformerEncoder(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.MRN_atten module¶
- class MRNAtten(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) with self-attention, implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.MRN_fix module¶
- class MRN_fix(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.MRN_phi module¶
- class MRN_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Recurrent Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations for the field phi.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
- class RecurrentGATConv(phi_size=1, hidden_size=512, heads=6, concat=False)[source]¶
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
UM2N.model.MRT module¶
- class MRTransformer(num_transformer_in=4, num_transformer_out=16, num_transformer_embed_dim=64, num_transformer_heads=4, num_transformer_layers=1, transformer_training_mask=False, transformer_key_padding_training_mask=False, transformer_attention_training_mask=False, transformer_training_mask_ratio_lower_bound=0.5, transformer_training_mask_ratio_upper_bound=0.9, deform_in_c=7, deform_out_type='coord', num_loop=3, device='cuda')[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing a transformer as feature extractor and recurrent graph-based deformations.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
- forward(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, poly_mesh=False)[source]¶
Forward pass for MRN.
- Args:
data (Data): Input data object containing mesh and feature info.
- Returns:
coord (Tensor): Deformed coordinates.
- move(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index, num_step=1, poly_mesh=False)[source]¶
- Move the mesh according to the learned deformation, for a given number of steps.
- Args:
data (Data): Input data object containing mesh and feature info.
num_step (int): Number of deformation steps.
- Returns:
coord (Tensor): Deformed coordinates.
UM2N.model.MRT_PE module¶
UM2N.model.MRT_phi module¶
- class MRT_phi(gfe_in_c=2, lfe_in_c=4, deform_in_c=7, num_loop=3)[source]¶
Bases: Module
- Mesh Refinement Network (MRN) implementing global and local feature extraction and recurrent graph-based deformations. The global feature extraction is performed by a transformer.
- Attributes:
num_loop (int): Number of loops for the recurrent layer.
gfe_out_c (int): Output channels for global feature extractor.
lfe_out_c (int): Output channels for local feature extractor.
hidden_size (int): Size of the hidden layer.
gfe (GlobalFeatExtractor): Global feature extractor.
lfe (LocalFeatExtractor): Local feature extractor.
lin (nn.Linear): Linear layer for feature transformation.
deformer (RecurrentGATConv): GAT-based deformer block.
Initialize MRN.
- Args:
gfe_in_c (int): Input channels for the global feature extractor.
lfe_in_c (int): Input channels for the local feature extractor.
deform_in_c (int): Input channels for the deformer block.
num_loop (int): Number of loops for the recurrent layer.
UM2N.model.deformer module¶
- class RecurrentGATConv(coord_size=2, hidden_size=512, heads=6, output_type='coord', concat=False, device='cuda')[source]¶
Bases: MessagePassing
Implements a Recurrent Graph Attention Network (GAT) Convolution layer.
- Attributes:
to_hidden (GATv2Conv): Graph Attention layer.
to_coord (nn.Sequential): Output layer for coordinates.
activation (nn.SELU): Activation function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
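The documented attributes suggest the recurrence sketched below. This is a minimal illustration: the layer widths, the input wiring (coordinates concatenated with the hidden state), and the toy graph are assumptions; only to_hidden (GATv2Conv), to_coord (nn.Sequential), and the SELU activation come from the attribute list above.

```python
import torch
from torch_geometric.nn import GATv2Conv

coord_size, hidden_size, heads = 2, 512, 6  # documented defaults
to_hidden = GATv2Conv(coord_size + hidden_size, hidden_size, heads=heads, concat=False)
to_coord = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, hidden_size // 2),
    torch.nn.SELU(),
    torch.nn.Linear(hidden_size // 2, coord_size),
)
activation = torch.nn.SELU()

coord = torch.rand(10, coord_size)                   # 10 mesh nodes
hidden = torch.zeros(10, hidden_size)                # recurrent state
edge_index = torch.tensor([[i for i in range(10)],
                           [(i + 1) % 10 for i in range(10)]])  # toy ring graph
for _ in range(3):                                   # num_loop refinement steps
    hidden = activation(to_hidden(torch.cat([coord, hidden], dim=-1), edge_index))
    coord = to_coord(hidden)                         # decode new coordinates
```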
UM2N.model.extractor module¶
- class GlobalFeatExtractor(in_c, out_c, drop_p=0.2, use_drop=True)[source]¶
Bases: Module
Custom PyTorch layer for global feature extraction.
The class employs multiple convolutional layers and dropout layers.
- Attributes:
conv1, conv2, conv3, conv4 (torch.nn.Conv2d): Convolutional layers.
dropout (torch.nn.Dropout): Dropout layer.
final_pool (torch.nn.AdaptiveAvgPool2d): Final pooling layer.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
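A minimal sketch of the documented layer stack. Channel counts, kernel sizes, and activations are assumptions; only the layer types (four Conv2d layers, Dropout, AdaptiveAvgPool2d) come from the attribute list above.

```python
import torch

# Hypothetical stack mirroring the documented attributes.
net = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, 3, padding=1), torch.nn.SELU(), torch.nn.Dropout(0.2),
    torch.nn.Conv2d(16, 32, 3, padding=1), torch.nn.SELU(), torch.nn.Dropout(0.2),
    torch.nn.Conv2d(32, 32, 3, padding=1), torch.nn.SELU(), torch.nn.Dropout(0.2),
    torch.nn.Conv2d(32, 16, 3, padding=1), torch.nn.SELU(),
    torch.nn.AdaptiveAvgPool2d(1),        # pool to one global vector per sample
)
feat = net(torch.randn(8, 1, 20, 20)).flatten(1)  # shape: (8, 16)
```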
- class LocalFeatExtractor(num_feat=10, out=16)[source]¶
Bases: MessagePassing
Custom PyTorch Geometric layer that performs feature extraction on the local graph structure.
The class extends the torch_geometric.nn.MessagePassing class and employs additive aggregation as the message-passing scheme.
- Attributes:
lin_1 (torch.nn.Linear): First linear layer.
lin_2 (torch.nn.Linear): Second linear layer.
lin_3 (torch.nn.Linear): Third linear layer.
activate (torch.nn.SELU): Activation function.
Initialize the layer.
- Args:
num_feat (int): Number of input features per node.
out (int): Number of output features per node.
- forward(input, edge_index)[source]¶
Forward pass.
- Args:
input (Tensor): Node features.
edge_index (Tensor): Edge indices.
- Returns:
Tensor: Updated node features.
- message(x_i, x_j)[source]¶
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
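To make the message-passing scheme concrete, here is a simplified, self-contained sketch of a layer with additive aggregation in the style described above. The real LocalFeatExtractor uses three linear layers, so the single-layer message function here is an assumption.

```python
import torch
from torch_geometric.nn import MessagePassing

class LocalFeatSketch(MessagePassing):
    def __init__(self, num_feat=10, out=16):
        super().__init__(aggr="add")  # additive aggregation, as documented
        self.lin_1 = torch.nn.Linear(2 * num_feat, out)
        self.activate = torch.nn.SELU()

    def forward(self, input, edge_index):
        return self.propagate(edge_index, x=input)

    def message(self, x_i, x_j):
        # x_i: receiver-node features, x_j: sender-node features, per edge
        return self.activate(self.lin_1(torch.cat([x_i, x_j - x_i], dim=-1)))

x = torch.randn(5, 10)                                   # 5 nodes, 10 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
out = LocalFeatSketch()(x, edge_index)                   # shape: (5, 16)
```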
UM2N.model.gatdeformer module¶
- class DeformGAT(in_channels: int, out_channels: int, heads: int = 1, concat: bool = False, negative_slope: float = 0.2, dropout: float = 0, bias: bool = False, **kwargs)[source]¶
Bases: MessagePassing
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(coords: Tensor | Tuple[Tensor, Tensor | None], features: Tensor | Tuple[Tensor, Tensor | None], edge_index: Tensor | SparseTensor, bd_mask, poly_mesh)[source]¶
Runs the forward pass of the module.
- message(x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) → Tensor[source]¶
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
UM2N.model.train_util module¶
- class TangleCounter(num_feat=10, out=16)[source]¶
Bases: MessagePassing
A PyTorch Geometric MessagePassing class for counting tangles in the mesh. This class is deprecated; do not use it unless you know what you are doing.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- message(x_i, x_j, x_new_i, x_new_j)[source]¶
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- count_dataset_tangle(dataset, model, device, method='inversion')[source]¶
Computes the average number of tangles in a dataset.
- Args:
dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.
- Returns:
float: The average number of tangles in the dataset.
- count_dataset_tangle_repeat_sampling(dataset, model, device, num_samples=1)[source]¶
Computes the average number of tangles in a dataset.
- Args:
dataset (Dataset): The PyTorch Geometric dataset.
model (torch.nn.Module): The PyTorch model.
device (torch.device): The device to run the computation on.
- Returns:
float: The average number of tangles in the dataset.
- evaluate(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]¶
Evaluates a model using the given data loader and loss function.
- Args:
loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.
- Returns:
float: The average evaluation loss across all batches.
- evaluate_repeat(dataset, model, device, loss_func, scaler=100, num_repeat=1)[source]¶
Evaluates model performance when sampling for different numbers of times. This function evaluates:
the average loss
the average number of tangles
- Args:
dataset (MeshDataset): The target dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
- Returns:
float: The average evaluation loss across all batches.
- evaluate_repeat_sampling(dataset, model, device, loss_func, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, scaler=100, batch_size=5, num_samples=1)[source]¶
Evaluates a model using the given data loader and loss function.
- Args:
dataset (MeshDataset): The dataset to evaluate.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
- Returns:
float: The average evaluation loss across all batches.
- evaluate_unsupervised(loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_eq_residual_loss=1, weight_chamfer_loss=1, scaler=100)[source]¶
Evaluates a model using the given data loader and loss function.
- Args:
loader (DataLoader): DataLoader object for the evaluation data.
model (torch.nn.Module): The PyTorch model to evaluate.
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss. Defaults to False.
- Returns:
float: The average evaluation loss across all batches.
- get_face_area(coord, face)[source]¶
- Calculates the area of a face, using the formula:
\(\mathrm{area} = 0.5\,\bigl(x_1(y_2 - y_3) + x_2(y_3 - y_1) + x_3(y_1 - y_2)\bigr)\)
- Args:
coord (torch.Tensor): The coordinates.
face (torch.Tensor): The face tensor.
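The formula above is the standard signed triangle area. A self-contained PyTorch version follows; the (3, num_faces) face layout is an assumption.

```python
import torch

def triangle_area(coord, face):
    # coord: (num_nodes, 2) vertex coordinates; face: (3, num_faces) indices.
    # Signed version of the documented formula; a negative value indicates
    # an inverted (tangled) element.
    x1, y1 = coord[face[0]].unbind(-1)
    x2, y2 = coord[face[1]].unbind(-1)
    x3, y3 = coord[face[2]].unbind(-1)
    return 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))

coord = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
face = torch.tensor([[0], [1], [2]])
print(triangle_area(coord, face))  # tensor([0.5000])
```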
- get_inversion_diff_loss(out_coord, tar_coord, face, batch_size, scaler=100)[source]¶
Calculates the inversion difference loss for a batch of meshes: the difference between the output area and the input area, restricted to the inverted elements.
- Args:
out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scaler (float): The loss weight.
- get_inversion_loss(out_coord, in_coord, face, batch_size, scheme='relu', scaler=100)[source]¶
Calculates the inversion loss for a batch of meshes.
- Args:
out_coord (torch.Tensor): The output coordinates.
in_coord (torch.Tensor): The input coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scaler (float): The loss weight.
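For intuition, one plausible reading of the 'relu' scheme, reusing triangle_area from the get_face_area sketch above: only negative (inverted) signed areas are penalised. This is an illustration, not the library implementation; the real function also receives in_coord and batch_size, which this sketch omits.

```python
import torch

def inversion_loss_sketch(out_coord, face, scaler=100):
    area = triangle_area(out_coord, face)    # signed areas; negative = inverted
    return torch.relu(-scaler * area).mean()  # positive areas contribute zero
```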
- get_inversion_node_loss(out_coord, tar_coord, face, batch_size, scaler=1000)[source]¶
Calculates the loss between the output nodes and the input nodes, for the inverted elements. This penalises the nodes involved in tangled elements.
- Args:
out_coord (torch.Tensor): The output coordinates.
tar_coord (torch.Tensor): The target coordinates.
face (torch.Tensor): The face tensor.
batch_size (int): The batch size.
scaler (float): The loss weight.
- get_sample_tangle(out_coords, in_coords, face)[source]¶
Return the number of tangled elements in a single sample.
- load_model(model, weight_path, strict=False)[source]¶
Loads pre-trained weights into a PyTorch model from a given file path.
- Args:
model (torch.nn.Module): The PyTorch model.
weight_path (str): File path to the pre-trained model weights.
- Returns:
torch.nn.Module: The model with loaded weights.
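load_model presumably wraps the standard state-dict pattern shown below (an assumption; shown with a stand-in model so the snippet runs on its own).

```python
import torch

model = torch.nn.Linear(4, 2)                 # stand-in for a UM2N model
torch.save(model.state_dict(), "weights.pth")

state = torch.load("weights.pth", map_location="cpu")
model.load_state_dict(state, strict=False)    # strict=False mirrors the default above
```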
- train(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100)[source]¶
- Trains a PyTorch model using the given data loader, optimizer, and loss function.
- Args:
loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.
- Returns:
float: The average training loss across all batches.
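A typical loop wiring these helpers together, using exactly the signatures documented above. Here, model and loader are assumed to come from the surrounding project (a UM2N model and a PyTorch Geometric DataLoader over a mesh dataset), and the import path is inferred from this page, so treat the snippet as a sketch rather than a complete program.

```python
import torch
from UM2N.model.train_util import train, evaluate  # import path assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lr is an assumption
loss_func = torch.nn.L1Loss()                              # loss choice is an assumption

for epoch in range(10):
    train_loss = train(loader, model, optimizer, device, loss_func)
    val_loss = evaluate(loader, model, device, loss_func)
    print(epoch, train_loss, val_loss)
```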
- train_unsupervised(loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_chamfer_loss=1, weight_eq_residual_loss=1, scaler=100)[source]¶
- Trains a PyTorch model using the given data loader, optimizer, and loss function.
- Args:
loader (DataLoader): DataLoader object for the training data.
model (torch.nn.Module): The PyTorch model to train.
optimizer (Optimizer): The optimizer (e.g., Adam, SGD).
device (torch.device): The device to run the computation on.
loss_func (callable): Loss function (e.g., MSE, Cross-Entropy).
use_jacob (bool): Whether or not to use Jacobian loss.
- Returns:
float: The average training loss across all batches.
UM2N.model.transformer_model module¶
- class MLP_model(input_channels, output_channels, list_hiddens=[128, 128], hidden_act='LeakyReLU', output_act='LeakyReLU', input_norm=None, dropout_prob=0.0)[source]¶
Bases: Module
Note that list_hiddens should be a list of hidden channel sizes, one per MLP layer, e.g. [64, 128, 64].
- Args:
input_channels (_type_): _description_
list_hiddens (_type_): _description_
output_channels (_type_): _description_
hidden_act (str, optional): _description_. Defaults to “LeakyReLU”.
output_act (str, optional): _description_. Defaults to “LeakyReLU”.
dropout_prob (float, optional): _description_. Defaults to 0.0.
input_norm (float, optional): one of [“BatchNorm1d”, “LayerNorm”]
- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
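A minimal usage sketch based on the signature above; the import path and the (batch, channels) input layout are assumptions.

```python
import torch
from UM2N.model.transformer_model import MLP_model  # import path assumed

mlp = MLP_model(input_channels=4, output_channels=16, list_hiddens=[64, 128, 64])
y = mlp(torch.randn(32, 4))  # (batch, input_channels) layout assumed
print(y.shape)               # expected: torch.Size([32, 16])
```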
- class TransformerBlock(embed_dim, num_heads, dense_ratio=4, list_dropout=[0.1, 0.1, 0.1], activation='GELU')[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, k, v, x_cls=None, key_padding_mask=None, attn_mask=None, return_attn=False)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class TransformerModel(*, input_dim, embed_dim, output_dim, num_heads=4, num_layers=3)[source]¶
Bases: Module
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, k, v, key_padding_mask=None, attention_mask=None)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
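Finally, a usage sketch for TransformerModel based on the signature above. The (batch, sequence, feature) input layout and the self-attention call pattern (passing the same tensor as x, k, and v) are assumptions.

```python
import torch
from UM2N.model.transformer_model import TransformerModel  # import path assumed

tm = TransformerModel(input_dim=4, embed_dim=64, output_dim=16)
x = torch.randn(2, 100, 4)  # (batch, num_nodes, input_dim) layout assumed
out = tm(x, x, x)           # queries, keys, values from the same tensor
print(out.shape)            # each node should map to output_dim features
```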