Source code for UM2N.model.train_util

# Author: Chunyang Wang
# GitHub Username: acse-cw1722
# Modified by Mingrui Zhang

import numpy as np
import torch
import torch.nn as nn

# FIXME: Hack to handle the case where pytorch3d cannot be installed (#41)
try:
    from pytorch3d.loss import chamfer_distance
except ImportError:

[docs] def chamfer_distance(*args): raise ImportError( "chamfer_distance depends on pytorch3d, which cannot be imported" )
from torch_geometric.loader import DataLoader from torch_geometric.nn import MessagePassing, knn_graph __all__ = [ "train", "train_unsupervised", "evaluate", "evaluate_unsupervised", "load_model", "TangleCounter", "count_dataset_tangle", "get_jacob_det", "get_inversion_diff_loss", "get_face_area", "count_dataset_tangle", "get_jacob_det", "get_face_area", "get_inversion_loss", "get_inversion_node_loss", "get_area_loss", "evaluate_repeat_sampling", "count_dataset_tangle_repeat_sampling", "evaluate_repeat", "get_sample_tangle", ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[docs] def get_face_area(coord, face): """ Calculates the area of a face. using formula: area = 0.5 * (x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2)) Args: coord (torch.Tensor): The coordinates. face (torch.Tensor): The face tensor. """ x = coord[:, 0][face] y = coord[:, 1][face] area = 0.5 * ( x[0, :] * (y[1, :] - y[2, :]) + x[1, :] * (y[2, :] - y[0, :]) + x[2, :] * (y[0, :] - y[1, :]) ) return area
[docs] def get_inversion_loss( out_coord, in_coord, face, batch_size, scheme="relu", scaler=100 ): """ Calculates the inversion loss for a batch of meshes. Args: out_coord (torch.Tensor): The output coordinates. in_coord (torch.Tensor): The input coordinates. face (torch.Tensor): The face tensor. batch_size (int): The batch size. alpha (float): The loss weight. """ loss = None out_area = get_face_area(out_coord, face) in_area = get_face_area(in_coord, face) # restore the sign of the area, ans scale it out_area = torch.sign(in_area) * out_area # hard penalty, use hard condition to penalize the negative area if scheme == "hard": # mask for negative area neg_mask = out_area < 0 neg_area = out_area[neg_mask] tar_area = in_area[neg_mask] # loss should be positive, so we are using -1 here. loss = -1 * ((neg_area / torch.abs(tar_area)).sum()) / batch_size # soft penalty, peanlize the negative area harder than the positive area elif scheme == "relu": loss = torch.nn.ReLU()(-1 * (out_area / torch.abs(in_area))).sum() / batch_size elif scheme == "log": epsilon = 1e-8 loss = ( -1 * torch.log(-1 * (out_area / torch.abs(in_area))).sum() + epsilon ) / batch_size return scaler * loss
[docs] def get_inversion_diff_loss(out_coord, tar_coord, face, batch_size, scaler=100): """ Calculates the inversion difference loss for a batch of meshes. That is the difference between the output area and the input area, in terms of the invereted elements. Args: out_coord (torch.Tensor): The output coordinates. tar_coord (torch.Tensor): The target coordinates. face (torch.Tensor): The face tensor. batch_size (int): The batch size. alpha (float): The loss weight. """ out_area = get_face_area(out_coord, face) tar_area = get_face_area(tar_coord, face) # restore the sign of the area, ans scale it out_area = scaler * torch.sign(tar_area) * out_area tar_area = scaler * torch.sign(tar_area) * tar_area # mask for negative area neg_mask = out_area < 0 inversion_diff = tar_area[neg_mask] - out_area[neg_mask] # loss should be positive, so we are using -1 here. loss = inversion_diff.sum() / batch_size return loss
[docs] def get_inversion_node_loss(out_coord, tar_coord, face, batch_size, scaler=1000): """ Calculates the loss between the ouput node and input node, for the inverted elements. This will penalise the node which are involved in the tangled elements. Args: out_coord (torch.Tensor): The output coordinates. tar_coord (torch.Tensor): The target coordinates. face (torch.Tensor): The face tensor. batch_size (int): The batch size. alpha (float): The loss weight. """ loss = torch.nn.L1Loss() out_area = get_face_area(out_coord, face) tar_area = get_face_area(tar_coord, face) # restore the sign of the area, ans scale it out_area = torch.sign(tar_area) * out_area tar_area = torch.sign(tar_area) * tar_area # mask for negative area neg_mask = out_area < 0 neg_face = face[:, neg_mask] neg_face = neg_face.reshape(-1) inv_nodes = out_coord[neg_face] tar_nodes = tar_coord[neg_face] node_diff = scaler * loss(inv_nodes, tar_nodes) loss = node_diff / batch_size zero = torch.tensor( 0.0, device=out_coord.device, dtype=out_coord.dtype, requires_grad=True ) return loss if len(inv_nodes) > 0 else zero
[docs] def get_area_loss(out_coord, tar_coord, face, batch_size, scaler=100): out_area = get_face_area(out_coord, face) tar_area = get_face_area(tar_coord, face) # restore the sign of the area, ans scale it out_area = scaler * torch.sign(tar_area) * out_area tar_area = scaler * torch.sign(tar_area) * tar_area # mask for negative area area_diff = torch.abs(tar_area - out_area) # area_diff = tar_area - out_area + 100 # loss should be positive, so we are using -1 here. loss = (area_diff.sum()) / batch_size return loss
[docs] def jacobLoss(model, out, data, loss_func): jacob_det = get_jacob_det(model, data) u_loss = loss_func(out, data.y) jacob_det_loss = loss_func(jacob_det, data.jacobian_det) # print("using jacobian") print("u_loss: ", u_loss.item()) print("jacob_det_loss: ", jacob_det_loss.item()) return u_loss + 0.01 * jacob_det_loss
[docs] def get_jacob_det(model, in_data): in_data.x.requires_grad_(True) in_data.mesh_feat.requires_grad_(True) out = model(in_data.to(device)) dXdx1 = torch.autograd.grad( outputs=out[:, 0], inputs=in_data.x, retain_graph=True, create_graph=False, grad_outputs=torch.ones_like(out[:, 0]), )[0][:, :2] dXdy1 = torch.autograd.grad( outputs=out[:, 1], inputs=in_data.x, retain_graph=True, create_graph=False, grad_outputs=torch.ones_like(out[:, 1]), )[0][:, :2] dXdx2 = torch.autograd.grad( outputs=out[:, 0], inputs=in_data.mesh_feat, retain_graph=True, create_graph=False, grad_outputs=torch.ones_like(out[:, 0]), )[0][:, :2] dXdy2 = torch.autograd.grad( outputs=out[:, 1], inputs=in_data.mesh_feat, retain_graph=True, create_graph=False, grad_outputs=torch.ones_like(out[:, 1]), )[0][:, :2] dXdx = dXdx1 + dXdx2 dXdy = dXdy1 + dXdy2 jacobian = torch.stack([dXdx, dXdy], dim=1) determinant = ( jacobian[:, 0, 0] * jacobian[:, 1, 1] - jacobian[:, 0, 1] * jacobian[:, 1, 0] ) return determinant
[docs] class TangleCounter(MessagePassing): """ A PyTorch Geometric Message Passing class for counting tangles in the mesh. This class is deprecated, do not use this option unless you know what you are doing. """ def __init__(self, num_feat=10, out=16): super().__init__(aggr="add") self.num_tangle = 0
[docs] def forward(self, x, x_new, edge_index): self.propagate(edge_index, x=x, x_new=x_new) return self.num_tangle
[docs] def message(self, x_i, x_j, x_new_i, x_new_j): coord_dim = 2 x_i = x_i[:, :coord_dim] x_j = x_j[:, :coord_dim] x_new_i = x_new_i[:, :coord_dim] x_new_j = x_new_j[:, :coord_dim] diff_x = x_i - x_j diff_x_new = x_new_i - x_new_j tangle = diff_x * diff_x_new tangle_arr = tangle.reshape(-1) self.num_tangle = torch.sum(tangle_arr < 0) return tangle
# def count_dataset_tangle(dataset, model, device, method="inversion"): # """ # Computes the average number of tangles in a dataset. # Args: # dataset (Dataset): The PyTorch Geometric dataset. # model (torch.nn.Module): The PyTorch model. # device (torch.device): The device to run the computation. # Returns: # float: The average number of tangles in the dataset. # """ # model.eval() # num_tangle = 0 # if (method == "inversion"): # loader = DataLoader(dataset=dataset, batch_size=1, # shuffle=False) # for data in loader: # with torch.no_grad(): # output_data = model(data.to(device)) # out_area = get_face_area(output_data, data.face) # in_area = get_face_area(data.x[:, :2], data.face) # # restore the sign of the area # out_area = torch.sign(in_area) * out_area # # mask for negative area # neg_mask = out_area < 0 # neg_area = out_area[neg_mask] # # calculate the loss, we want it normalized by the batch size # # and loss should be positive, so we are using -1 here. # num_tangle += len(neg_area) # return num_tangle / len(dataset) # # deprecated, do not use this option unless you know what you are doing # elif (method == "msg"): # for i in range(len(dataset)): # data = dataset[i].to(device) # with torch.no_grad(): # output_data = model(data) # input_edge = data.edge_index # mesh = data.x[:, :2] # mesh_new = output_data # Counter = TangleCounter() # num_tangle += Counter(mesh, mesh_new, input_edge).item() # num_tangle = num_tangle / len(dataset) # return num_tangle # def count_dataset_tangle(dataset, model, device, method="inversion"): # """ # Computes the average number of tangles in a dataset. # Args: # dataset (Dataset): The PyTorch Geometric dataset. # model (torch.nn.Module): The PyTorch model. # device (torch.device): The device to run the computation. # Returns: # float: The average number of tangles in the dataset. # """ # model.eval() # num_tangle = 0 # if (method == "inversion"): # loader = DataLoader(dataset=dataset, batch_size=1, # shuffle=False) # for data in loader: # (output_coord, model_raw_output, out_monitor), (phix, phiy) = model(data.to(device)) # # Compute the new mesh coord given model output phi # bs = 1 # feat_dim = data.mesh_feat.shape[-1] # # mesh_feat [coord_x, coord_y, u, hessian_norm] # node_num = data.mesh_feat.view(bs, -1, feat_dim).shape[1] # out_area = get_face_area(output_coord, data.face) # in_area = get_face_area(data.x[:, :2], data.face) # # restore the sign of the area # out_area = torch.sign(in_area) * out_area # # mask for negative area # neg_mask = out_area < 0 # neg_area = out_area[neg_mask] # # calculate the loss, we want it normalized by the batch size # # and loss should be positive, so we are using -1 here. # num_tangle += len(neg_area) # return num_tangle / len(dataset) # # deprecated, do not use this option unless you know what you are doing # elif (method == "msg"): # for i in range(len(dataset)): # data = dataset[i].to(device) # with torch.no_grad(): # output_data = model(data) # input_edge = data.edge_index # mesh = data.x[:, :2] # mesh_new = output_data # Counter = TangleCounter() # num_tangle += Counter(mesh, mesh_new, input_edge).item() # num_tangle = num_tangle / len(dataset) # return num_tangle
[docs] def train( loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100, ): """ Trains a PyTorch model using the given data loader, optimizer, and loss function. Args: loader (DataLoader): DataLoader object for the training data. model (torch.nn.Module): The PyTorch model to train. optimizer (Optimizer): The optimizer (e.g., Adam, SGD). device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). use_jacob (bool): Whether or not to use Jacobian loss. Returns: float: The average training loss across all batches. """ bs = loader.batch_size model.train() total_loss = 0 total_deform_loss = 0 total_inversion_loss = 0 total_inversion_diff_loss = 0 total_area_loss = 0 total_chamfer_loss = 0.0 for batch in loader: optimizer.zero_grad() data = batch.to(device) out = model(data) loss = 0 inversion_loss = 0 deform_loss = 0 inversion_diff_loss = 0 area_loss = 0 # deformation loss deform_loss = 1000 * ( loss_func(out, data.y) if not use_jacob else jacobLoss(model, out, data, loss_func) ) # Inversion loss if use_inversion_loss: inversion_loss = get_inversion_loss( out, data.y, data.face, batch_size=bs, scaler=scaler ) if use_area_loss: area_loss = get_area_loss(out, data.y, data.face, bs, scaler) chamfer_loss = 100 * chamfer_distance(out.unsqueeze(0), data.y.unsqueeze(0))[0] loss = ( weight_deform_loss * deform_loss + inversion_loss + inversion_diff_loss + weight_area_loss * area_loss + weight_chamfer_loss * chamfer_loss ) # Jacobian loss if use_jacob: loss.backward(retain_graph=True) else: loss.backward() # print_parameter_grad(model) optimizer.step() total_loss += loss.item() total_deform_loss += weight_deform_loss * deform_loss.item() total_inversion_loss += inversion_loss.item() if use_inversion_loss else 0 # noqa total_inversion_diff_loss += ( inversion_diff_loss.item() if use_inversion_diff_loss else 0 ) # noqa total_area_loss += weight_area_loss * area_loss.item() if use_area_loss else 0 total_chamfer_loss += weight_chamfer_loss * chamfer_loss.item() res = { "total_loss": total_loss / len(loader), "deform_loss": total_deform_loss / len(loader), } if use_inversion_loss: res["inversion_loss"] = total_inversion_loss / len(loader) if use_inversion_diff_loss: res["inversion_diff_loss"] = total_inversion_diff_loss / len(loader) if use_area_loss: res["area_loss"] = total_area_loss / len(loader) res["chamfer_loss"] = total_chamfer_loss / len(loader) return res
[docs] def evaluate( loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, weight_deform_loss=1.0, weight_area_loss=1.0, weight_chamfer_loss=0.0, scaler=100, ): """ Evaluates a model using the given data loader and loss function. Args: loader (DataLoader): DataLoader object for the evaluation data. model (torch.nn.Module): The PyTorch model to evaluate. device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). use_jacob (bool): Whether or not to use Jacobian loss. Defaults to. Returns: float: The average evaluation loss across all batches. """ bs = loader.batch_size model.eval() total_loss = 0 total_deform_loss = 0 total_inversion_loss = 0 total_inversion_diff_loss = 0 total_area_loss = 0 total_chamfer_loss = 0.0 for batch in loader: data = batch.to(device) loss = 0 deform_loss = 0 inversion_loss = 0 inversion_diff_loss = 0 area_loss = 0 with torch.no_grad(): out = model(data) deform_loss = 1000 * ( loss_func(out, data.y) if not use_jacob else jacobLoss(model, out, data, loss_func) ) inversion_loss = 0 if use_inversion_loss: inversion_loss = get_inversion_loss( out, data.y, data.face, batch_size=bs, scaler=scaler ) if use_area_loss: area_loss = get_area_loss(out, data.y, data.face, bs, scaler) chamfer_loss = ( 100 * chamfer_distance(out.unsqueeze(0), data.y.unsqueeze(0))[0] ) loss = ( weight_deform_loss * deform_loss + inversion_loss + inversion_diff_loss + weight_area_loss * area_loss + weight_chamfer_loss * chamfer_loss ) total_loss += loss.item() total_deform_loss += weight_deform_loss * deform_loss.item() total_inversion_diff_loss += ( inversion_diff_loss.item() if use_inversion_diff_loss else 0 ) # noqa total_inversion_loss += inversion_loss.item() if use_inversion_loss else 0 # noqa total_area_loss += ( weight_area_loss * area_loss.item() if use_area_loss else 0 ) total_chamfer_loss += weight_chamfer_loss * chamfer_loss.item() res = { "total_loss": total_loss / len(loader), "deform_loss": total_deform_loss / len(loader), } if use_inversion_loss: res["inversion_loss"] = total_inversion_loss / len(loader) if use_inversion_diff_loss: res["inversion_diff_loss"] = total_inversion_diff_loss / len(loader) if use_area_loss: res["area_loss"] = total_area_loss / len(loader) res["chamfer_loss"] = total_chamfer_loss / len(loader) return res
[docs] def interpolate(u, ori_mesh_x, ori_mesh_y, moved_x, moved_y): """ u: [bs, node_num, 1] ori_mesh_x: [bs, node_num, 1] ori_mesh_y: [bs, node_num, 1] moved_x: [bs, node_num, 1] moved_y: [bs, node_num, 1] Note: node_num equals to sample_num """ batch_size = u.shape[0] sample_num = u.shape[1] # print(f"batch size: {batch_size}, sample num: {sample_num}") u_interpolateds = [] for bs in range(batch_size): # For a sample point of interest, we need to do a weighted summation over all other sample points # To avoid using a loop, we expand an additonal dim of size sample_num original_mesh = torch.cat((ori_mesh_x[bs], ori_mesh_y[bs]), dim=-1) moved_mesh = ( torch.cat((moved_x[bs], moved_y[bs]), dim=-1) .unsqueeze(-2) .repeat(1, sample_num, 1) ) # print(f"new mesh shape {moved_mesh.shape}, original mesh shape {original_mesh.shape}") # print((moved_mesh - original_mesh),(moved_mesh - original_mesh).shape) # print("check dimension ", (moved_mesh - original_mesh)[:, 0]) # The second dimension of distance is the different sample points distance = -torch.norm(moved_mesh - original_mesh, dim=-1) * np.sqrt(sample_num) # print('raw distance ', torch.norm(moved_mesh - original_mesh, dim=-1)) # print('distance ', torch.norm(moved_mesh - original_mesh, dim=-1)* np.sqrt(sample_num)) normalize = nn.Softmax(dim=-1) weight = normalize(distance) # print('weight shape ', weight.shape, u[bs].shape) # print('weight ', weight, u, u[bs].permute(1, 0) * weight) # print(u.shape, weight.shape) u_interpolateds.append( torch.sum(u[bs].permute(1, 0) * weight, dim=-1).unsqueeze(-1) ) # print(f"interpolated shape: {u_interpolateds[-1]}") # print('inte ', u_interpolated) return torch.stack(u_interpolateds, dim=0)
def _generate_samples( num_meshes, num_samples_per_mesh, coords, solution, monitor, redundant_sample_ratio=10, device="cuda", ): meshes = torch.tensor( np.random.uniform( 0, 1, (num_meshes, redundant_sample_ratio * num_samples_per_mesh, 2) ), dtype=torch.float, ).to(device) solution_input = solution.repeat(num_meshes, 1, 1) monitor_input = monitor.repeat(num_meshes, 1, 1) coords_x = coords[:, :, 0].unsqueeze(-1).repeat(num_meshes, 1, 1) coords_y = coords[:, :, 1].unsqueeze(-1).repeat(num_meshes, 1, 1) new_meshes_x = meshes[:, :, 0].unsqueeze(-1) new_meshes_y = meshes[:, :, 1].unsqueeze(-1) solutions = interpolate( solution_input, coords_x, coords_y, new_meshes_x, new_meshes_y ) monitors = interpolate( monitor_input, coords_x, coords_y, new_meshes_x, new_meshes_y ) meshes_ = [] soluitons_ = [] monitors_ = [] # resample according to the monitor values for bs in range(monitors.shape[0]): prob = monitors[bs, :, 0] / torch.sum(monitors[bs, :, 0]) index = np.random.choice( a=meshes.shape[1], size=num_samples_per_mesh, replace=False, p=prob.cpu().numpy(), ) # print(torch.max(prob), torch.min(prob), torch.max(monitors), torch.min(monitors)) meshes_.append(meshes[bs, index, :]) soluitons_.append(solutions[bs, index, :]) monitors_.append(monitors[bs, index, :]) return ( torch.stack(meshes_, dim=0), torch.stack(soluitons_, dim=0), torch.stack(monitors_, dim=0), )
[docs] def generate_samples(bs, num_samples_per_mesh, data, num_meshes=5, device="cuda"): # num_meshes = 5 # num_nodes = coord_ori.shape[-2] // bs # samples_q = data.mesh_feat[:, :4].view(bs, -1, 4) meshes_collection = [] solutions_collection = [] monitors_collection = [] for b in range(bs): coords = data.mesh_feat[:, :2].view(bs, -1, 2)[b, :, :].view(1, -1, 2) solution = data.mesh_feat[:, 2].view(bs, -1, 1)[b, :, :].view(1, -1, 1) monitor = data.mesh_feat[:, 3].view(bs, -1, 1)[b, :, :].view(1, -1, 1) meshes, solutions, monitors = _generate_samples( num_meshes=num_meshes, num_samples_per_mesh=num_samples_per_mesh, coords=coords, solution=solution, monitor=monitor, device=device, ) # print(f"output meshes: {meshes.shape} solutions: {solutions.shape} monitor: {monitors.shape}") # merge the addtional sampled attributes (mesh, solution, monitor) to a large graph within one sample meshes_collection.append( torch.cat([coords.view(-1, 2), meshes.view(-1, 2)], dim=0) ) solutions_collection.append( torch.cat([solution.view(-1, 1), solutions.view(-1, 1)], dim=0) ) monitors_collection.append( torch.cat([monitor.view(-1, 1), monitors.view(-1, 1)], dim=0) ) # merge all enhanced sampled samples along the batch size dimension meshes_input = torch.stack(meshes_collection, dim=0) solutions_input = torch.stack(solutions_collection, dim=0) monitors_input = torch.stack(monitors_collection, dim=0) # merge all batched sampled attributes along feature dim for transformer key and value samples_kv = torch.cat([meshes_input, solutions_input, monitors_input], dim=-1) return samples_kv
[docs] def compute_finite_difference(field): # Field: [bs, x_shape, y_shape] f_x = torch.zeros_like(field) f_x[:, :-1, :] = torch.diff(field, dim=-2) f_x[:, -1, :] = f_x[:, -2, :] f_y = torch.zeros_like(field) f_y[:, :, :-1] = torch.diff(field, dim=-1) f_y[:, :, -1] = f_y[:, :, -2] inv_dx = field.shape[-2] - 1 inv_dy = field.shape[-1] - 1 return f_x * inv_dx, f_y * inv_dy
[docs] def generate_samples_structured_grid(coords, field, grid_resolution=100, device="cuda"): num_meshes = coords.shape[0] nx = grid_resolution ny = grid_resolution x = np.linspace(0, 1, nx) y = np.linspace(0, 1, ny) uniform_grid = ( torch.tensor(np.array(np.meshgrid(x, y)), dtype=torch.float) .reshape(1, 2, -1) .repeat(num_meshes, 1, 1) .permute(0, 2, 1) .to(device) ) field_input = field.view(num_meshes, -1, field.shape[-1]) coords_x = coords[:, :, 0].unsqueeze(-1) coords_y = coords[:, :, 1].unsqueeze(-1) new_meshes_x = uniform_grid[:, :, 0].unsqueeze(-1) new_meshes_y = uniform_grid[:, :, 1].unsqueeze(-1) # Interpolate to dense structured grid field = interpolate(field_input, coords_x, coords_y, new_meshes_x, new_meshes_y) field_x_, field_y_ = compute_finite_difference( field.view(field.shape[0], grid_resolution, grid_resolution) ) field_x_ = field_x_.view(num_meshes, -1, 1) field_y_ = field_y_.view(num_meshes, -1, 1) # Interpolate back to original mesh field_x = interpolate(field_x_, new_meshes_x, new_meshes_y, coords_x, coords_y) field_y = interpolate(field_y_, new_meshes_x, new_meshes_y, coords_x, coords_y) # print(f"coords shape {coords.shape} interp field x {field_x.shape} interp field y {field_y.shape}") return uniform_grid, field, field_x_, field_y_, field_x, field_y
[docs] def construct_graph(sampled_coords, num_neighbors=6, device="cuda"): bs = sampled_coords.shape[0] num_per_mesh = sampled_coords.shape[1] batch = ( torch.tensor([x for x in range(bs)]) .unsqueeze(-1) .repeat(1, num_per_mesh) .reshape(-1) .to(device) ) edge_index = knn_graph( sampled_coords.view(-1, 2).to(device), k=num_neighbors, batch=batch, loop=False ) return edge_index
[docs] def compute_phi_hessian( mesh_query_x, mesh_query_y, phix, phiy, out_monitor, bs, data, loss_func, finite_difference_grad=False, ): feat_dim = data.mesh_feat.shape[-1] node_num = data.mesh_feat.view(bs, -1, feat_dim).shape[1] sampled_num = mesh_query_x.shape[0] // bs sampled_ratio = sampled_num // node_num # equation residual loss loss_eq_residual = torch.tensor(0.0) # Convex loss loss_convex = torch.tensor(0.0) if phix is not None and phiy is not None: if not finite_difference_grad: # print(f"phix: {phix.shape}, phiy: {phiy.shape}") hessian_seed = torch.ones(phix.shape).to(device) phixx = torch.autograd.grad( phix, mesh_query_x, grad_outputs=hessian_seed, retain_graph=True, create_graph=True, allow_unused=True, )[0] phixy = torch.autograd.grad( phix, mesh_query_y, grad_outputs=hessian_seed, retain_graph=True, create_graph=True, allow_unused=True, )[0] phiyx = torch.autograd.grad( phiy, mesh_query_x, grad_outputs=hessian_seed, retain_graph=True, create_graph=True, allow_unused=True, )[0] phiyy = torch.autograd.grad( phiy, mesh_query_y, grad_outputs=hessian_seed, retain_graph=True, create_graph=True, allow_unused=True, )[0] else: # TODO: the finite difference method here returns the (u_y, u_x), whose order is different from ad _, _, _, _, phixy, phixx = generate_samples_structured_grid( torch.stack([mesh_query_x, mesh_query_y], dim=-1), phix ) _, _, _, _, phiyy, phiyx = generate_samples_structured_grid( torch.stack([mesh_query_x, mesh_query_y], dim=-1), phiy ) # print(f"phix grad: {phix_grad.shape}, phiy grad: {phiy_grad.shape}") # phixx = phix_grad[:, 0] # phixy = phix_grad[:, 1] # phiyx = phiy_grad[:, 0] # phiyy = phiy_grad[:, 1] # print(f"phixx grad: {torch.sum(phixx)}, phixy grad: {torch.sum(phixy)}, phiyx grad: {torch.sum(phiyx)}, phiyy grad: {torch.sum(phiyy)}") det_hessian = (phixx + 1) * (phiyy + 1) - phixy * phiyx det_hessian = det_hessian.view(bs, sampled_num, 1) # jacobian_x = data.mesh_feat[:, 4].view(bs, node_num, 1) # jacobian_y = data.mesh_feat[:, 5].view(bs, node_num, 1) hessian_norm = ( data.mesh_feat[:, 3].view(bs, node_num, 1).repeat(1, sampled_ratio, 1) ) # solution = data.mesh_feat[:, 1].resahpe(bs, node_num, 1) # original_mesh_x = data.mesh_feat[:, 0].view(bs, node_num, 1) # original_mesh_y = data.mesh_feat[:, 1].view(bs, node_num, 1) moved_x = phix.view(bs, sampled_num, 1) + mesh_query_x.view(bs, sampled_num, 1) moved_y = phiy.view(bs, sampled_num, 1) + mesh_query_y.view(bs, sampled_num, 1) # print(f"diff x:{torch.abs(original_mesh_x - moved_x).mean()}, diff y:{torch.abs(original_mesh_y - moved_y).mean()}") # Interpolate on new moved mesh # print(hessian_norm.shape, mesh_query_x.shape, moved_x.shape) hessian_norm_ = interpolate( hessian_norm, mesh_query_x.view(bs, sampled_num, 1), mesh_query_y.view(bs, sampled_num, 1), moved_x, moved_y, ) enhanced_hessian_norm = hessian_norm_ # + out_monitor.view(bs, node_num, 1) # =========================== jacobian related attempts ================== # jac_x = interpolate(jacobian_x, original_mesh_x, original_mesh_y, moved_x, moved_y) # jac_y = interpolate(jacobian_y, original_mesh_x, original_mesh_y, moved_x, moved_y) # # alpha = torch.sum(torch.sqrt(torch.abs(jac_x)**2 + torch.abs(jac_y)**2), dim=(-1, -2)) / node_num**2 # alpha = 1.0 # phixx = phixx.reshape(bs, node_num, 1) # phiyx = phiyx.reshape(bs, node_num, 1) # phiyy = phiyy.reshape(bs, node_num, 1) # phixy = phixy.reshape(bs, node_num, 1) # jac_xi_1 = jac_x * (1 + phixx) + jac_y * phiyx # jac_xi_2 = jac_x * phixy + jac_y * (1 + phiyy) # print(torch.sum(jac_xi_1), torch.sum(jac_xi_2)) # monitor = monitor_grad(alpha, jac_xi_1, jac_xi_2) /1000 # =========================== # TODO: here the 1 + 5 * heesian_norm is the monitor function used when generating dataset lhs = (1 + 5 * enhanced_hessian_norm) * det_hessian rhs = torch.sum(hessian_norm, dim=(1, 2)) / sampled_num rhs = rhs.unsqueeze(-1).repeat(1, sampled_num).unsqueeze(-1) loss_eq_residual = 1000 * loss_func(lhs, rhs) # print(torch.sum(hessian_norm_ - hessian_norm), hessian_norm_.shape, det_hessian.shape, lhs.shape, rhs.shape) # print(f"diff between interpolation jac x {torch.sum(jacobian_x - jac_x)} alpha: {alpha} monitor: {torch.sum(monitor)} det_hessian {torch.sum(det_hessian)} lhs {torch.sum(lhs)} rhs {torch.sum(rhs)}") # Convex loss # if use_convex_loss: loss_convex = torch.mean( torch.min(torch.tensor(0).type_as(phixx).to(device), 1 + phixx) ** 2 + torch.min(torch.tensor(0).type_as(phiyy).to(device), 1 + phiyy) ** 2 ) return ( loss_eq_residual, loss_convex, )
[docs] def model_forward(bs, data, model, use_add_random_query=True): # Create mesh query for deformer, seperate from the original mesh as feature for encoder mesh_query_x = data.mesh_feat[:, 0].view(-1, 1).detach().clone() mesh_query_y = data.mesh_feat[:, 1].view(-1, 1).detach().clone() mesh_query_x.requires_grad = True mesh_query_y.requires_grad = True mesh_query = torch.cat([mesh_query_x, mesh_query_y], dim=-1) num_nodes = mesh_query.shape[-2] // bs coord_ori_x = data.mesh_feat[:, 0].view(-1, 1) coord_ori_y = data.mesh_feat[:, 1].view(-1, 1) coord_ori_x.requires_grad = True coord_ori_y.requires_grad = True coord_ori = torch.cat([coord_ori_x, coord_ori_y], dim=-1) if use_add_random_query: # Generate random mesh queries for unsupervised learning sampled_queries = generate_samples( bs=bs, num_samples_per_mesh=num_nodes, num_meshes=1, data=data, device=device, ) sampled_queries_edge_index = construct_graph( sampled_queries[:, :, :2], num_neighbors=6 ) mesh_sampled_queries_x = sampled_queries[:, :, 0].view(-1, 1).detach() mesh_sampled_queries_y = sampled_queries[:, :, 1].view(-1, 1).detach() mesh_sampled_queries_x.requires_grad = True mesh_sampled_queries_y.requires_grad = True mesh_sampled_queries = torch.cat( [mesh_sampled_queries_x, mesh_sampled_queries_y], dim=-1 ).view(-1, 2) num_nodes = coord_ori.shape[-2] // bs # input_kv = generate_samples( # bs=bs, num_samples_per_mesh=num_nodes, data=data, device=device # ) else: mesh_sampled_queries = None # input_kv = None sampled_queries_edge_index = None # input_q = data.mesh_feat[:, :4] input_q = data.mesh_feat # print(f"batch size: {bs}, num_nodes: {num_nodes}, input q", input_q.shape, "input_kv ", input_kv.shape) # if not use_add_random_query: # mesh_sampled_queries = None # sampled_queries_edge_index = None (output_coord_all, output, out_monitor), (phix, phiy) = model( data, input_q, input_q, mesh_query, mesh_sampled_queries, sampled_queries_edge_index, ) if not use_add_random_query: phix = None phiy = None # (output_coord_all, output, out_monitor), (phix, phiy) = model(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index) output_coord = output_coord_all[: num_nodes * bs] # print(output_coord_all.shape, output_coord.shape) # mesh_query_x_all = torch.cat([mesh_query_x, mesh_sampled_queries[:, :, 0].view(-1, 1)], dim=0) # mesh_query_y_all = torch.cat([mesh_query_y, mesh_sampled_queries[:, :, 1].view(-1, 1)], dim=0) if use_add_random_query: mesh_query_x_all = mesh_sampled_queries_x mesh_query_y_all = mesh_sampled_queries_y else: mesh_query_x_all = mesh_query_x mesh_query_y_all = mesh_query_y return ( output_coord, output, out_monitor, phix, phiy, mesh_query_x_all, mesh_query_y_all, )
# def chamfer_distance(mesh_test, mesh_target): # # (batch, node, feature) # batch_size = mesh_test.shape[0] # node_num = mesh_test.shape[1] # chamfer_distance_val = 0.0 # for bs in range(batch_size): # mesh_stack = mesh_test[bs].unsqueeze(-2).repeat(1, node_num, 1) # # print("output mesh ", mesh_stack.shape) # # diff shape: (distance of a node in mesh_test to every node of mesh_target , number of node, feature) # diff = torch.norm(mesh_stack - mesh_target[bs], dim=-1) # min_diff = torch.min(diff, dim=0) # # min_diff[0] values, min_diff[1] index # min_diff_sum = torch.sum(min_diff[0]) # chamfer_distance_val += min_diff_sum / node_num # # diff = mesh_stack - mesh_target[bs] # # print(diff, diff[0]) # # print("min diff ", min_diff) # return chamfer_distance_val / batch_size
[docs] def sample_nodes_by_monitor( meshes, meshes_target, monitors, num_samples_per_mesh=100, random_seed=666 ): # Set random seed np.random.seed(random_seed) # resample according to the monitor values meshes_ = [] meshes_target_ = [] monitors = torch.max(monitors, torch.tensor([0.0]).to(device)) for bs in range(monitors.shape[0]): prob = monitors[bs, :, 0] / torch.sum(monitors[bs, :, 0]) index = np.random.choice( a=meshes.shape[1], size=num_samples_per_mesh, replace=False, p=prob.cpu().numpy(), ) # print(torch.max(prob), torch.min(prob), torch.max(monitors), torch.min(monitors)) meshes_.append(meshes[bs, index, :]) meshes_target_.append(meshes_target[bs, index, :]) test_meshes = torch.stack(meshes_, dim=0) target_meshes = torch.stack(meshes_target_, dim=0) return test_meshes, target_meshes
[docs] def train_unsupervised( loader, model, optimizer, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_chamfer_loss=1, weight_eq_residual_loss=1, scaler=100, ): """ Trains a PyTorch model using the given data loader, optimizer, and loss function. Args: loader (DataLoader): DataLoader object for the training data. model (torch.nn.Module): The PyTorch model to train. optimizer (Optimizer): The optimizer (e.g., Adam, SGD). device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). use_jacob (bool): Whether or not to use Jacobian loss. Returns: float: The average training loss across all batches. """ bs = loader.batch_size model.train() total_loss = 0 total_eq_residual_loss = 0 total_convex_loss = 0 total_deform_loss = 0 total_inversion_loss = 0 total_inversion_diff_loss = 0 total_area_loss = 0 total_chamfer_loss_loss = 0 for batch in loader: optimizer.zero_grad() data = batch.to(device) ( output_coord, output, out_monitor, phix, phiy, mesh_query_x_all, mesh_query_y_all, ) = model_forward(bs, data, model, use_add_random_query=use_add_random_query) if use_add_random_query: loss_eq_residual, loss_convex = compute_phi_hessian( mesh_query_x_all, mesh_query_y_all, phix, phiy, out_monitor, bs, data, loss_func=loss_func, finite_difference_grad=finite_difference_grad, ) else: loss_eq_residual, loss_convex = torch.tensor(0.0), torch.tensor(0.0) if not use_convex_loss: loss_convex = torch.tensor(0.0) loss = 0 inversion_loss = 0 deform_loss = torch.tensor(0.0) inversion_diff_loss = 0 area_loss = 0 # deformation loss deform_loss = 1000 * ( loss_func(output_coord, data.y) if not use_jacob else jacobLoss(model, output_coord, data, loss_func) ) # Chamfer loss coord_dim = output_coord.shape[-1] # Sampling according to monitor values, sampled 100 nodes per mesh mesh_test_raw = output_coord.view(bs, -1, coord_dim) mesh_target_raw = data.y.view(bs, -1, coord_dim) monitors = data.mesh_feat[:, 2].view(bs, -1, 1) mesh_test_sampled, mesh_target_sampled = sample_nodes_by_monitor( meshes=mesh_test_raw, meshes_target=mesh_target_raw, monitors=monitors ) # chamfer_loss = 100 * (chamfer_distance(mesh_test_sampled, mesh_target_sampled) + chamfer_distance(mesh_target_sampled, mesh_test_sampled)) chamfer_loss = ( 100 * chamfer_distance(output_coord.unsqueeze(0), data.y.unsqueeze(0))[0] ) # print(output_coord.shape, data.y.shape, chamfer_loss) # Inversion loss if use_inversion_loss: inversion_loss = get_inversion_loss( output_coord, data.y, data.face, batch_size=bs, scaler=scaler ) # if use_area_loss: area_loss = get_area_loss(output_coord, data.y, data.face, bs, scaler) loss = ( weight_deform_loss * deform_loss + weight_chamfer_loss * chamfer_loss + inversion_loss + inversion_diff_loss + weight_area_loss * area_loss + weight_eq_residual_loss * loss_eq_residual + loss_convex ) # Jacobian loss if use_jacob: loss.backward(retain_graph=True) else: loss.backward() # print_parameter_grad(model) optimizer.step() total_loss += loss.item() total_eq_residual_loss += loss_eq_residual.item() total_convex_loss += loss_convex.item() if use_convex_loss else 0 total_deform_loss += weight_deform_loss * deform_loss.item() total_chamfer_loss_loss += weight_chamfer_loss * chamfer_loss.item() total_inversion_loss += inversion_loss.item() if use_inversion_loss else 0 # noqa total_inversion_diff_loss += ( inversion_diff_loss.item() if use_inversion_diff_loss else 0 ) # noqa total_area_loss += weight_area_loss * area_loss.item() res = { "total_loss": total_loss / len(loader), "deform_loss": total_deform_loss / len(loader), "equation_residual": total_eq_residual_loss / len(loader), "chamfer_loss": total_chamfer_loss_loss / len(loader), } if use_convex_loss: res["convex_loss"] = total_convex_loss / len(loader) if use_inversion_loss: res["inversion_loss"] = total_inversion_loss / len(loader) if use_inversion_diff_loss: res["inversion_diff_loss"] = total_inversion_diff_loss / len(loader) if use_area_loss: res["area_loss"] = total_area_loss / len(loader) return res
[docs] def evaluate_unsupervised( loader, model, device, loss_func, use_jacob=False, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, use_convex_loss=False, use_add_random_query=True, finite_difference_grad=True, weight_area_loss=1, weight_deform_loss=1, weight_eq_residual_loss=1, weight_chamfer_loss=1, scaler=100, ): """ Evaluates a model using the given data loader and loss function. Args: loader (DataLoader): DataLoader object for the evaluation data. model (torch.nn.Module): The PyTorch model to evaluate. device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). use_jacob (bool): Whether or not to use Jacobian loss. Defaults to. Returns: float: The average evaluation loss across all batches. """ bs = loader.batch_size model.eval() total_loss = 0 total_eq_residual_loss = 0 total_convex_loss = 0 total_deform_loss = 0 total_inversion_loss = 0 total_inversion_diff_loss = 0 total_area_loss = 0 total_chamfer_loss_loss = 0 for batch in loader: data = batch.to(device) loss = 0 deform_loss = 0 inversion_loss = 0 inversion_diff_loss = 0 area_loss = 0 # with torch.no_grad(): ( output_coord, output, out_monitor, phix, phiy, mesh_query_x_all, mesh_query_y_all, ) = model_forward(bs, data, model, use_add_random_query=use_add_random_query) if use_add_random_query: loss_eq_residual, loss_convex = compute_phi_hessian( mesh_query_x_all, mesh_query_y_all, phix, phiy, out_monitor, bs, data, loss_func=loss_func, finite_difference_grad=finite_difference_grad, ) else: loss_eq_residual, loss_convex = torch.tensor(0.0), torch.tensor(0.0) if not use_convex_loss: loss_convex = torch.tensor(0.0) deform_loss = 1000 * ( loss_func(output_coord, data.y) if not use_jacob else jacobLoss(model, output_coord, data, loss_func) ) inversion_loss = 0 if use_inversion_loss: inversion_loss = get_inversion_loss( output_coord, data.y, data.face, batch_size=bs, scaler=scaler ) # if use_area_loss: area_loss = get_area_loss(output_coord, data.y, data.face, bs, scaler) # Chamfer loss coord_dim = output_coord.shape[-1] # Sampling according to monitor values, sampled 100 nodes per mesh mesh_test_raw = output_coord.view(bs, -1, coord_dim) mesh_target_raw = data.y.view(bs, -1, coord_dim) monitors = data.mesh_feat[:, 2].view(bs, -1, 1) mesh_test_sampled, mesh_target_sampled = sample_nodes_by_monitor( meshes=mesh_test_raw, meshes_target=mesh_target_raw, monitors=monitors ) chamfer_loss = ( 100 * chamfer_distance(output_coord.unsqueeze(0), data.y.unsqueeze(0))[0] ) loss = ( weight_deform_loss * deform_loss + weight_chamfer_loss * chamfer_loss + inversion_loss + inversion_diff_loss + weight_area_loss * area_loss + weight_eq_residual_loss * loss_eq_residual + loss_convex ) total_loss += loss.item() total_eq_residual_loss += loss_eq_residual.item() total_chamfer_loss_loss += weight_chamfer_loss * chamfer_loss.item() total_convex_loss += loss_convex.item() if use_convex_loss else 0 total_deform_loss += weight_deform_loss * deform_loss.item() total_inversion_diff_loss += ( inversion_diff_loss.item() if use_inversion_diff_loss else 0 ) # noqa total_inversion_loss += inversion_loss.item() if use_inversion_loss else 0 # noqa total_area_loss += weight_area_loss * area_loss.item() res = { "total_loss": total_loss / len(loader), "deform_loss": total_deform_loss / len(loader), "equation_residual": total_eq_residual_loss / len(loader), "chamfer_loss": total_chamfer_loss_loss / len(loader), } if use_convex_loss: res["convex_loss"] = total_convex_loss / len(loader) if use_inversion_loss: res["inversion_loss"] = total_inversion_loss / len(loader) if use_inversion_diff_loss: res["inversion_diff_loss"] = total_inversion_diff_loss / len(loader) if use_area_loss: res["area_loss"] = total_area_loss / len(loader) return res
[docs] def get_sample_tangle(out_coords, in_coords, face): """ Return the number of tangled elements in a single sample. """ out_area = get_face_area(out_coords, face) in_area = get_face_area(in_coords, face) out_area = torch.sign(in_area) * out_area neg_mask = out_area < 0 neg_area = out_area[neg_mask] num_tangle = len(neg_area) return num_tangle
[docs] def count_dataset_tangle(dataset, model, device, method="inversion"): """ Computes the average number of tangles in a dataset. Args: dataset (Dataset): The PyTorch Geometric dataset. model (torch.nn.Module): The PyTorch model. device (torch.device): The device to run the computation. Returns: float: The average number of tangles in the dataset. """ model.eval() num_tangle = 0 if method == "inversion": loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) bs = loader.batch_size for data in loader: with torch.no_grad(): data = data.to(device) # # Create mesh query for deformer, seperate from the original mesh as feature for encoder # mesh_query_x = data.mesh_feat[:, 0].view(-1, 1).detach().clone() # mesh_query_y = data.mesh_feat[:, 1].view(-1, 1).detach().clone() # mesh_query_x.requires_grad = True # mesh_query_y.requires_grad = True # mesh_query = torch.cat([mesh_query_x, mesh_query_y], dim=-1) # num_nodes = mesh_query.shape[-2] // bs # # Generate random mesh queries for unsupervised learning # sampled_queries = generate_samples(bs=bs, num_samples_per_mesh=num_nodes, data=data, num_meshes=1, device=device) # sampled_queries_edge_index = construct_graph(sampled_queries[:, :, :2]) # mesh_sampled_queries_x = sampled_queries[:, :, 0].view(-1, 1).detach() # mesh_sampled_queries_y = sampled_queries[:, :, 1].view(-1, 1).detach() # mesh_sampled_queries_x.requires_grad = True # mesh_sampled_queries_y.requires_grad = True # mesh_sampled_queries = torch.cat([mesh_sampled_queries_x, mesh_sampled_queries_y], dim=-1).view(-1, 2) # coord_ori_x = data.mesh_feat[:, 0].view(-1, 1) # coord_ori_y = data.mesh_feat[:, 1].view(-1, 1) # coord_ori_x.requires_grad = True # coord_ori_y.requires_grad = True # coord_ori = torch.cat([coord_ori_x, coord_ori_y], dim=-1) # num_nodes = coord_ori.shape[-2] // bs # input_q = torch.cat([mesh_query, data.mesh_feat[:, 2:4]], dim=-1) # input_kv = generate_samples(bs=bs, num_samples_per_mesh=num_nodes, data=data, device=device) # # print(f"batch size: {bs}, num_nodes: {num_nodes}, input q", input_q.shape, "input_kv ", input_kv.shape) # (output_coord_all, output, out_monitor), (phix, phiy) = model(data.to(device), input_q.to(device), input_q.to(device), mesh_query.to(device), mesh_sampled_queries.to(device), sampled_queries_edge_index) # # (output_coord_all, output, out_monitor), (phix, phiy) = model(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index) # output_data = output_coord_all[:num_nodes*bs] ( output_coord, output, out_monitor, phix, phiy, mesh_query_x_all, mesh_query_y_all, ) = model_forward(bs, data, model) out_area = get_face_area(output_coord, data.face) in_area = get_face_area(data.x[:, :2], data.face) # restore the sign of the area out_area = torch.sign(in_area) * out_area # mask for negative area neg_mask = out_area < 0 neg_area = out_area[neg_mask] # calculate the loss, we want it normalized by the batch size # and loss should be positive, so we are using -1 here. num_tangle += len(neg_area) return num_tangle / len(dataset) # deprecated, do not use this option unless you know what you are doing elif method == "msg": for i in range(len(dataset)): data = dataset[i].to(device) with torch.no_grad(): output_data = model(data) input_edge = data.edge_index mesh = data.x[:, :2] mesh_new = output_data Counter = TangleCounter() num_tangle += Counter(mesh, mesh_new, input_edge).item() num_tangle = num_tangle / len(dataset) return num_tangle
[docs] def evaluate_repeat_sampling( dataset, model, device, loss_func, use_inversion_loss=False, use_inversion_diff_loss=False, use_area_loss=False, scaler=100, batch_size=5, num_samples=1, ): """ Evaluates a model using the given data loader and loss function. Args: loader (DataLoader): DataLoader object for the evaluation data. model (torch.nn.Module): The PyTorch model to evaluate. device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). use_jacob (bool): Whether or not to use Jacobian loss. Defaults to. Returns: float: The average evaluation loss across all batches. """ model.eval() bs = batch_size loaders = [ DataLoader(dataset=dataset, batch_size=bs, shuffle=False) for i in range(num_samples) ] data_iters = [iter(loader) for loader in loaders] total_loss = 0 total_deform_loss = 0 total_inversion_loss = 0 total_inversion_diff_loss = 0 total_area_loss = 0 for i in range(len(loaders[0])): data_list = [next(data_iter) for data_iter in data_iters] data_list = [data.to(device) for data in data_list] loss = 0 deform_loss = 0 inversion_loss = 0 inversion_diff_loss = 0 area_loss = 0 with torch.no_grad(): out = [model(data) for data in data_list] out = torch.stack(out, dim=0) out = torch.mean(out, dim=0) deform_loss = 1000 * (loss_func(out, data_list[0].y)) if use_area_loss: area_loss = get_area_loss( out, data_list[0].y, data_list[0].face, bs, scaler ) loss = inversion_loss + deform_loss total_loss += loss.item() total_deform_loss += deform_loss.item() total_inversion_diff_loss += ( inversion_diff_loss.item() if use_inversion_diff_loss else 0 ) # noqa total_inversion_loss += inversion_loss.item() if use_inversion_loss else 0 # noqa total_area_loss += area_loss.item() if use_area_loss else 0 res = { "total_loss": total_loss / len(loaders[0]), "deform_loss": total_deform_loss / len(loaders[0]), } if use_inversion_loss: res["inversion_loss"] = total_inversion_loss / len(loaders[0]) if use_inversion_diff_loss: res["inversion_diff_loss"] = total_inversion_diff_loss / len(loaders[0]) # noqa if use_area_loss: res["area_loss"] = total_area_loss / len(loaders[0]) return res
[docs] def count_dataset_tangle_repeat_sampling(dataset, model, device, num_samples=1): """ Computes the average number of tangles in a dataset. Args: dataset (Dataset): The PyTorch Geometric dataset. model (torch.nn.Module): The PyTorch model. device (torch.device): The device to run the computation. Returns: float: The average number of tangles in the dataset. """ model.eval() num_tangle = 0 loaders = [ DataLoader(dataset=dataset, batch_size=1, shuffle=False) for i in range(num_samples) ] data_iters = [iter(loader) for loader in loaders] for i in range(len(loaders[0])): with torch.no_grad(): data_list = [next(data_iter) for data_iter in data_iters] data_list = [data.to(device) for data in data_list] output_data = [model(data.to(device)) for data in data_list] output_data = torch.stack(output_data, dim=0) output_data = torch.mean(output_data, dim=0) out_area = get_face_area(output_data, data_list[0].face) in_area = get_face_area(data_list[0].x[:, :2], data_list[0].face) # restore the sign of the area out_area = torch.sign(in_area) * out_area # mask for negative area neg_mask = out_area < 0 neg_area = out_area[neg_mask] # calculate the loss, we want it normalized by the batch size # and loss should be positive, so we are using -1 here. num_tangle += len(neg_area) return num_tangle / len(dataset)
[docs] def evaluate_repeat(dataset, model, device, loss_func, scaler=100, num_repeat=1): """ Evaluates model performance when sampling for different number of times. this function will evaluate: 1. the average loss 2. the average number of tangles Args: dataset (MeshDataset): The target dataset to evaluate. model (torch.nn.Module): The PyTorch model to evaluate. device (torch.device): The device to run the computation on. loss_func (callable): Loss function (e.g., MSE, Cross-Entropy). Returns: float: The average evaluation loss across all batches. """ model.eval() loaders = [ DataLoader(dataset=dataset, batch_size=1, shuffle=False) for i in range(num_repeat) ] data_iters = [iter(loader) for loader in loaders] num_tangle = 0 total_loss = 0 total_deform_loss = 0 total_area_loss = 0 for i in range(len(loaders[0])): data_list = [next(data_iter) for data_iter in data_iters] data_list = [data.to(device) for data in data_list] loss = 0 deform_loss = 0 inversion_loss = 0 area_loss = 0 with torch.no_grad(): out = [model(data) for data in data_list] out = torch.stack(out, dim=0) out = torch.mean(out, dim=0) # calculate the loss deform_loss = 1000 * (loss_func(out, data_list[0].y)) area_loss = get_area_loss(out, data_list[0].y, data_list[0].face, 1, scaler) loss = inversion_loss + deform_loss total_loss += loss.item() total_deform_loss += deform_loss.item() total_area_loss += area_loss.item() # calculate the number of tangles in_area = get_face_area(data_list[0].x[:, :2], data_list[0].face) out_area = get_face_area(out, data_list[0].face) out_area = torch.sign(in_area) * out_area neg_mask = out_area < 0 neg_area = out_area[neg_mask] num_tangle += len(neg_area) return { "total_loss": total_loss / len(dataset), "deform_loss": total_deform_loss / len(dataset), "tangle": num_tangle / len(dataset), "area_loss": total_area_loss / len(dataset), }
[docs] def load_model(model, weight_path, strict=False): """ Loads pre-trained weights into a PyTorch model from a given file path. Args: model (torch.nn.Module): The PyTorch model. weight_path (str): File path to the pre-trained model weights. Returns: torch.nn.Module: The model with loaded weights. """ try: model.load_state_dict(torch.load(weight_path), strict=strict) except RuntimeError: state_dict = torch.load(weight_path, map_location=torch.device("cpu")) model.load_state_dict(state_dict, strict=strict) return model