Source code for UM2N.test.tangle

# Author: Chunyang Wang
# GitHub Username: acse-cw1722

import os

import firedrake as fd
import matplotlib.pyplot as plt
import movement as mv
import torch

# Device that samples are moved to before being passed to a model in this
# module: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Names exported by a star-import of this module.
__all__ = ["check_dataset_tangle", "plot_prediction", "plot_sample"]


def check_tangle_pi(model, x):
    """Placeholder that has no implementation yet.

    Both arguments are accepted and ignored; the function always returns
    ``None``.
    """
    return None


def check_dataset_tangle(
    dataset,
    model,
    n_elem_x,
    n_elem_y,
):
    """
    Return the mean tangling-check result over all samples of a dataset.

    For every sample, the model output is written into the coordinates of a
    fresh ``n_elem_x`` x ``n_elem_y`` unit-square mesh, and the mesh is then
    checked for tangling.

    NOTE(review): the per-sample value is whatever
    ``MeshTanglingChecker.check()`` returns -- presumably the number of
    tangled elements -- so the result is a mean count per sample rather than
    a percentage in [0, 100]. Confirm against the ``movement`` version in use.

    :arg dataset: sized, integer-indexable collection of samples; each sample
        must provide ``.to(device)`` and be accepted by ``model``.
    :arg model: callable mapping a sample to a tensor whose columns 0 and 1
        are written into the x and y mesh coordinates.
    :arg n_elem_x: first argument passed to ``fd.UnitSquareMesh``.
    :arg n_elem_y: second argument passed to ``fd.UnitSquareMesh``.
    :return: sum of the per-sample check results divided by
        ``len(dataset)``, or ``0.0`` when the dataset is empty.
    """
    num_samples = len(dataset)
    if num_samples == 0:
        # Nothing to check; avoids a ZeroDivisionError in the final division.
        return 0.0

    num_tangled = 0
    for idx in range(num_samples):
        # A fresh mesh per sample: the checker is constructed on the
        # untouched mesh before its coordinates are overwritten.
        mesh = fd.UnitSquareMesh(n_elem_x, n_elem_y)
        checker = mv.MeshTanglingChecker(mesh, mode="warn")
        check_item = dataset[idx]
        # Bring the prediction back to host memory before converting:
        # ``Tensor.numpy()`` raises for tensors that live on a CUDA device.
        out = model(check_item.to(device)).detach().cpu().numpy()
        mesh.coordinates.dat.data[:, 0] = out[:, 0]
        mesh.coordinates.dat.data[:, 1] = out[:, 1]
        num_tangled += checker.check()
    return num_tangled / num_samples
def plot_prediction(
    data_set, model, prediction_dir, mode, n_elem_x, n_elem_y, loss_fn, savefig=True
):
    """
    Plot every sample of ``data_set`` by delegating to ``plot_sample``.

    Samples are fetched by integer index ``0 .. len(data_set) - 1``; that
    index is forwarded as ``idx`` and every other argument is passed through
    unchanged. Returns ``None``.
    """
    for sample_idx in range(len(data_set)):
        plot_sample(
            model=model,
            val_item=data_set[sample_idx],
            prediction_dir=prediction_dir,
            loss_fn=loss_fn,
            n_elem_x=n_elem_x,
            n_elem_y=n_elem_y,
            idx=sample_idx,
            mode=mode,
            savefig=savefig,
        )
def plot_sample(
    model,
    val_item,
    prediction_dir,
    loss_fn,
    n_elem_x,
    n_elem_y,
    idx,
    mode,
    savefig=True,
):
    """
    Plot the target mesh and the model-predicted mesh of one sample.

    The sample is moved to the module-level ``device`` and passed through
    ``model``. The target ``.y`` and the prediction are each written into the
    coordinates of their own ``n_elem_x`` x ``n_elem_y`` unit-square mesh and
    drawn side by side. The tangling-check result of the predicted mesh and
    the loss (multiplied by 1000) are printed on the figure.

    :arg model: callable mapping the sample to a tensor of mesh coordinates.
    :arg val_item: sample providing ``.to(device)`` and a target ``.y``.
    :arg prediction_dir: directory the image is written to; created if it
        does not exist.
    :arg loss_fn: callable ``loss_fn(prediction, target)`` whose result
        supports ``.item()``.
    :arg n_elem_x: first argument passed to ``fd.UnitSquareMesh``.
    :arg n_elem_y: second argument passed to ``fd.UnitSquareMesh``.
    :arg idx: sample index, used in the output file name.
    :arg mode: prefix of the output file name.
    :arg savefig: when true, save ``{mode}_plot_{idx}.png`` into
        ``prediction_dir`` and close the figure; when false, the figure is
        left open for the caller.
    """
    # Use the object returned by ``.to`` for both the forward pass and the
    # loss so prediction and target live on the same device, whether ``.to``
    # moves the sample in place or returns a copy.
    item = val_item.to(device)
    out = model(item)
    # calculate the loss
    loss = 1000 * loss_fn(out, item.y).item()

    # Host-side arrays for the mesh coordinates: ``Tensor.numpy()`` raises
    # for tensors that live on a CUDA device.
    out = out.detach().cpu().numpy()
    target = item.y
    if isinstance(target, torch.Tensor):
        target = target.detach().cpu().numpy()

    # construct the mesh
    val_mesh = fd.UnitSquareMesh(n_elem_x, n_elem_y)
    val_new_mesh = fd.UnitSquareMesh(n_elem_x, n_elem_y)
    # init checker on the untouched mesh, before its coordinates change
    checker = mv.MeshTanglingChecker(val_new_mesh, mode="warn")
    # construct the predicted/target mesh
    val_mesh.coordinates.dat.data[:] = target[:]
    val_new_mesh.coordinates.dat.data[:] = out[:]
    num_tangle = checker.check()

    # plot the mesh, tangle/loss info
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(17, 8))
    fd.triplot(val_mesh, axes=ax1)
    fd.triplot(val_new_mesh, axes=ax2)
    ax1.set_title("Target mesh")
    ax2.set_title("Predicted mesh")
    ax2.text(
        0.5,
        -0.05,
        f"Num Tangle: {num_tangle}",
        ha="center",
        va="center",
        transform=ax2.transAxes,
        fontsize=14,
    )
    fig.text(0.5, 0.01, f"Loss: {loss:.4f}", ha="center", va="center", fontsize=16)

    if savefig:
        if prediction_dir:
            # Avoid a FileNotFoundError when the output directory is missing.
            os.makedirs(prediction_dir, exist_ok=True)
        fig.savefig(os.path.join(prediction_dir, f"{mode}_plot_{idx}.png"))
        # pyplot keeps every figure alive until it is closed; release it so
        # plotting a whole dataset in a loop does not accumulate figures.
        plt.close(fig)