Source code for UM2N.test.bench_swirl

# Author: Chunyang Wang
# GitHub username: chunyang-w

import os
import pickle
import time
from pprint import pprint  # noqa

import firedrake as fd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

import UM2N
from UM2N.model.train_util import model_forward


def get_log_og(log_path, idx):
    """
    Read a log file from the dataset log dir and return the values stored in it.
    """
    df = pd.read_csv(os.path.join(log_path, f"log{idx:04d}.csv"))
    return {
        "error_og": df["error_og"][0],
        "error_adapt": df["error_adapt"][0],
        "time": df["time"][0],
    }
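
# Usage sketch (illustrative only, not part of this module): read the Monge-Ampere
# timing for sample 0 from a dataset's log directory. The "data/swirl/log" path is a
# hypothetical placeholder; any directory holding log0000.csv, log0001.csv, ... works.
#
#     ma_time = get_log_og("data/swirl/log", 0)["time"]
#     print(f"MA mesh generation took {ma_time:.3f}s for sample 0")
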
class SwirlEvaluator:
    """
    Evaluate error for the advection swirl problem:
        1. Solver implementation for the swirl problem
        2. Error & time evaluation
    """

    def __init__(
        self,
        mesh,
        mesh_coarse,
        mesh_fine,
        mesh_new,
        mesh_model,
        dataset,
        model,
        eval_dir,
        ds_root,
        **kwargs,
    ):  # noqa
        """
        Init the problem:
            1. define the problem on the fine mesh and the coarse mesh
            2. init function spaces on the fine & coarse mesh
        """
        self.device = kwargs.pop("device", "cuda")
        self.model_used = kwargs.pop("model_used", "MRTransformer")

        # mesh vars
        self.mesh_coarse = mesh_coarse  # coarse mesh
        self.mesh = mesh  # mesh buffer for solving equations
        self.mesh_fine = mesh_fine  # fine mesh
        self.mesh_new = mesh_new  # adapted mesh by MA
        self.mesh_model = mesh_model  # adapted mesh by model

        # evaluation vars
        self.dataset = dataset  # dataset containing all data
        self.model = model  # the NN model
        self.eval_dir = eval_dir  # evaluation root dir
        self.ds_root = ds_root
        self.log_path = os.path.join(eval_dir, "log")
        self.plot_path = os.path.join(eval_dir, "plot")
        self.plot_more_path = os.path.join(eval_dir, "plot_more")
        self.plot_data_path = os.path.join(eval_dir, "plot_data")
        self.save_interval = kwargs.pop("save_interval", 5)
        self.num_samples_to_eval = kwargs.pop("num_samples_to_eval", 100)

        # Init coords setup
        self.init_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.init_coord_fine = (
            self.mesh_fine.coordinates.vector().array().reshape(-1, 2)
        )  # noqa
        self.best_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.adapt_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)  # noqa

        # error measuring vars
        self.error_adapt_list = []
        self.error_og_list = []
        self.best_error_iter = 0

        # X and Y coordinates
        self.x, self.y = fd.SpatialCoordinate(mesh)
        self.x_fine, self.y_fine = fd.SpatialCoordinate(self.mesh_fine)

        # function space on coarse mesh
        self.scalar_space = fd.FunctionSpace(self.mesh, "CG", 1)
        self.vector_space = fd.VectorFunctionSpace(self.mesh, "CG", 1)
        self.tensor_space = fd.TensorFunctionSpace(self.mesh, "CG", 1)

        # function space on fine mesh
        self.scalar_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        self.vector_space_fine = fd.VectorFunctionSpace(self.mesh_fine, "CG", 1)  # noqa

        # Test/Trial function on coarse mesh
        self.du_trial = fd.TrialFunction(self.scalar_space)
        self.phi = fd.TestFunction(self.scalar_space)

        # Test/Trial function on fine mesh
        self.du_trial_fine = fd.TrialFunction(self.scalar_space_fine)
        self.phi_fine = fd.TestFunction(self.scalar_space_fine)

        # normal function on coarse / fine mesh
        self.n = fd.FacetNormal(self.mesh)
        self.n_fine = fd.FacetNormal(self.mesh_fine)

        # simulation params
        self.T = kwargs.pop("T", 1)
        self.t = 0.0
        self.n_step = kwargs.pop("n_step", 500)
        self.threshold = (
            self.T / 2
        )  # time point at which the swirl direction gets reversed  # noqa
        self.dt = self.T / self.n_step
        # self.dt = kwargs.pop("dt", 1e-3)
        self.dtc = fd.Constant(self.dt)

        # initial condition params
        self.sigma = kwargs.pop("sigma", (0.05 / 3))
        self.alpha = kwargs.pop("alpha", 1.5)
        self.r_0 = kwargs.pop("r_0", 0.2)
        self.x_0 = kwargs.pop("x_0", 0.25)
        self.y_0 = kwargs.pop("y_0", 0.25)

        # initial condition of u on coarse / fine mesh
        u_init_exp = UM2N.get_u_0(
            self.x, self.y, self.r_0, self.x_0, self.y_0, self.sigma
        )  # noqa
        u_init_exp_fine = UM2N.get_u_0(
            self.x_fine, self.y_fine, self.r_0, self.x_0, self.y_0, self.sigma
        )  # noqa
        self.u_init = fd.Function(self.scalar_space).interpolate(u_init_exp)
        self.u_init_fine = fd.Function(self.scalar_space_fine).interpolate(
            u_init_exp_fine
        )  # noqa

        # PDE vars on coarse & fine mesh
        # solution field u
        self.u = fd.Function(self.scalar_space).assign(self.u_init)
        self.u1 = fd.Function(self.scalar_space)
        self.u2 = fd.Function(self.scalar_space)
        self.u_fine = fd.Function(self.scalar_space_fine).assign(self.u_init_fine)  # noqa
        self.u1_fine = fd.Function(self.scalar_space_fine)
        self.u2_fine = fd.Function(self.scalar_space_fine)
        self.u_in = fd.Constant(0.0)
        self.u_in_fine = fd.Constant(0.0)

        # temp vars for saving u on coarse & fine mesh
        self.u_cur = fd.Function(
            self.scalar_space
        )  # solution from current time step  # noqa
        self.u_cur_fine = fd.Function(self.scalar_space_fine)
        self.u_hess = fd.Function(
            self.scalar_space
        )  # buffer for hessian solver usage  # noqa

        # buffers
        self.u_fine_buffer = fd.Function(self.scalar_space_fine).assign(
            self.u_init_fine
        )  # noqa
        self.coarse_adapt = fd.Function(self.scalar_space)
        self.coarse_2_fine = fd.Function(self.scalar_space_fine)
        self.coarse_2_fine_original = fd.Function(self.scalar_space_fine)

        # velocity field - the swirl: c
        self.c = fd.Function(self.vector_space)
        self.c_fine = fd.Function(self.vector_space_fine)
        self.cn = 0.5 * (fd.dot(self.c, self.n) + abs(fd.dot(self.c, self.n)))
        self.cn_fine = 0.5 * (
            fd.dot(self.c_fine, self.n_fine) + abs(fd.dot(self.c_fine, self.n_fine))
        )  # noqa

        # PDE problem LHS (mass matrix) on coarse & fine mesh
        self.a = self.phi * self.du_trial * fd.dx(domain=self.mesh)
        self.a_fine = self.phi_fine * self.du_trial_fine * fd.dx(domain=self.mesh_fine)  # noqa

        # PDE problem RHS on coarse & fine mesh
        # on coarse mesh
        self.L1 = self.dtc * (
            self.u * fd.div(self.phi * self.c) * fd.dx(domain=self.mesh)  # noqa
            - fd.conditional(
                fd.dot(self.c, self.n) < 0,
                self.phi * fd.dot(self.c, self.n) * self.u_in,
                0.0,
            )
            * fd.ds(domain=self.mesh)  # noqa
            - fd.conditional(
                fd.dot(self.c, self.n) > 0,
                self.phi * fd.dot(self.c, self.n) * self.u,
                0.0,
            )
            * fd.ds(domain=self.mesh)  # noqa
            - (self.phi("+") - self.phi("-"))
            * (self.cn("+") * self.u("+") - self.cn("-") * self.u("-"))
            * fd.dS(domain=self.mesh)
        )  # noqa
        self.L2 = fd.replace(self.L1, {self.u: self.u1})
        self.L3 = fd.replace(self.L1, {self.u: self.u2})
        # on fine mesh
        self.L1_fine = self.dtc * (
            self.u_fine
            * fd.div(self.phi_fine * self.c_fine)
            * fd.dx(domain=self.mesh_fine)  # noqa
            - fd.conditional(
                fd.dot(self.c_fine, self.n_fine) < 0,
                self.phi_fine * fd.dot(self.c_fine, self.n_fine) * self.u_in_fine,
                0.0,
            )
            * fd.ds(domain=self.mesh_fine)  # noqa
            - fd.conditional(
                fd.dot(self.c_fine, self.n_fine) > 0,
                self.phi_fine * fd.dot(self.c_fine, self.n_fine) * self.u_fine,
                0.0,
            )
            * fd.ds(domain=self.mesh_fine)  # noqa
            - (self.phi_fine("+") - self.phi_fine("-"))
            * (
                self.cn_fine("+") * self.u_fine("+")
                - self.cn_fine("-") * self.u_fine("-")
            )
            * fd.dS(domain=self.mesh_fine)
        )  # noqa
        self.L2_fine = fd.replace(self.L1_fine, {self.u_fine: self.u1_fine})
        self.L3_fine = fd.replace(self.L1_fine, {self.u_fine: self.u2_fine})

        # vars for storing final solutions
        self.du = fd.Function(self.scalar_space)
        self.du_fine = fd.Function(self.scalar_space_fine)

        # PDE solver (on coarse & fine mesh) setup:
        params = {
            "ksp_type": "preonly",
            "pc_type": "bjacobi",
            "sub_pc_type": "ilu",
        }  # noqa
        # On coarse mesh
        self.prob1 = fd.LinearVariationalProblem(self.a, self.L1, self.du)
        self.solv1 = fd.LinearVariationalSolver(self.prob1, solver_parameters=params)  # noqa
        self.prob2 = fd.LinearVariationalProblem(self.a, self.L2, self.du)
        self.solv2 = fd.LinearVariationalSolver(self.prob2, solver_parameters=params)  # noqa
        self.prob3 = fd.LinearVariationalProblem(self.a, self.L3, self.du)
        self.solv3 = fd.LinearVariationalSolver(self.prob3, solver_parameters=params)  # noqa
        # On fine mesh
        self.prob1_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L1_fine, self.du_fine
        )  # noqa
        self.solv1_fine = fd.LinearVariationalSolver(
            self.prob1_fine, solver_parameters=params
        )  # noqa
        self.prob2_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L2_fine, self.du_fine
        )  # noqa
        self.solv2_fine = fd.LinearVariationalSolver(
            self.prob2_fine, solver_parameters=params
        )  # noqa
        self.prob3_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L3_fine, self.du_fine
        )  # noqa
        self.solv3_fine = fd.LinearVariationalSolver(
            self.prob3_fine, solver_parameters=params
        )  # noqa
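
    # Editorial note (not part of the original source): the forms assembled above are a
    # reading of one stage of the upwind DG discretisation of the advection equation
    # du/dt + div(c u) = 0. Each stage solves, for the increment du,
    #
    #     \int phi * du dx = dt * [ \int u div(phi c) dx
    #                               - \int_{inflow boundary}  phi (c.n) u_in ds
    #                               - \int_{outflow boundary} phi (c.n) u ds
    #                               - \int_{interior facets} [phi] ((c.n)^+ u)_upwind dS ],
    #
    # where (c.n)^+ = 0.5 * (c.n + |c.n|) is the upwind factor stored in self.cn.
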

    def solve_u(self, t):
        """
        Solve the PDE problem using an RK (SSPRK) scheme on the coarse mesh and
        store the solution field in the variable self.u_cur.
        """
        c_exp = UM2N.get_c(self.x, self.y, t, alpha=self.alpha)
        c_temp = fd.Function(self.vector_space).interpolate(c_exp)
        self.c.project(c_temp)

        self.solv1.solve()
        self.u1.assign(self.u + self.du)

        self.solv2.solve()
        self.u2.assign(0.75 * self.u + 0.25 * (self.u1 + self.du))

        self.solv3.solve()
        self.u_cur.assign((1.0 / 3.0) * self.u + (2.0 / 3.0) * (self.u2 + self.du))

    def solve_u_fine(self, t):
        """
        Solve the PDE problem using an RK (SSPRK) scheme on the fine mesh and
        store the solution field in the variable self.u_cur_fine.
        """
        c_exp = UM2N.get_c(self.x_fine, self.y_fine, t, alpha=self.alpha)
        c_temp = fd.Function(self.vector_space_fine).interpolate(c_exp)
        self.c_fine.project(c_temp)

        self.solv1_fine.solve()
        self.u1_fine.assign(self.u_fine + self.du_fine)

        self.solv2_fine.solve()
        self.u2_fine.assign(0.75 * self.u_fine + 0.25 * (self.u1_fine + self.du_fine))  # noqa

        self.solv3_fine.solve()
        self.u_cur_fine.assign(
            (1.0 / 3.0) * self.u_fine + (2.0 / 3.0) * (self.u2_fine + self.du_fine)
        )  # noqa
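
    # Editorial note (not part of the original source): solve_u and solve_u_fine
    # implement the three-stage SSP-RK3 (Shu-Osher) scheme. With du denoting the
    # increment obtained from the mass-matrix solve against L1/L2/L3, the updates are
    #
    #     u1      = u^n + du(u^n)
    #     u2      = 3/4 u^n + 1/4 (u1 + du(u1))
    #     u^{n+1} = 1/3 u^n + 2/3 (u2 + du(u2)),
    #
    # which matches the assignments to self.u1, self.u2 and self.u_cur above.
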

    def project_u_(self):
        self.u.project(self.u_fine_buffer)
        return

    def make_log_dir(self):
        UM2N.mkdir_if_not_exist(self.log_path)

    def make_plot_dir(self):
        UM2N.mkdir_if_not_exist(self.plot_path)

    def make_plot_more_dir(self):
        UM2N.mkdir_if_not_exist(self.plot_more_path)

    def make_plot_data_dir(self):
        UM2N.mkdir_if_not_exist(self.plot_data_path)
[docs] def eval_problem(self, model_name="model"): print("In eval problem") self.t = 0.0 step = 0 idx = 0 eval_cnt = 0 res = { "deform_loss": None, # nodal position loss "tangled_element": None, # tangled elements on a mesh # noqa "error_og": None, # PDE error on original uniform mesh # noqa "error_model": None, # PDE error on model generated mesh # noqa "error_ma": None, # PDE error on MA generated mesh # noqa "error_reduction_MA": None, # PDE error reduced by using MA mesh # noqa "error_reduction_model": None, # PDE error reduced by using model mesh # noqa "time_consumption_model": None, # time consumed generating mesh inferenced by the model # noqa "time_consumption_MA": None, # time consumed generating mesh by Monge-Ampere method # noqa "acceration_ratio": None, # time_consumption_ma / time_consumption_model # noqa } for i in range(self.n_step): print("evalutation, time: ", self.t) # data loading from raw file raw_data_path = self.dataset.file_names[idx] raw_data = np.load(raw_data_path, allow_pickle=True).item() data_t = raw_data.get("swirl_params")["t"] y = raw_data.get("y") # error tracking lists init self.error_adapt_list = [] self.error_og_list = [] sample = next( iter(DataLoader([self.dataset[idx]], batch_size=1, shuffle=False)) ) # solve PDE problem on fine mesh self.solve_u_fine(self.t) if abs(self.t - data_t) < 1e-5: print( f"---- evaluating samples: step: {step}, t: {self.t:.5f}, data_t: {data_t:.5f}" ) # noqa # Evaluation time step hit # initiate model inferencing ... self.model.eval() bs = 1 sample = sample.to(self.device) self.model = self.model.to(self.device) with torch.no_grad(): start = time.perf_counter() if self.model_used == "MRTransformer" or self.model_used == "M2T": data = sample ( output_coord, output, out_monitor, phix, phiy, mesh_query_x_all, mesh_query_y_all, ) = model_forward( bs, data, self.model, use_add_random_query=False, ) out = output_coord # # Create mesh query for deformer, seperate from the original mesh as feature for encoder # mesh_query_x = ( # sample.mesh_feat[:, 0].view(-1, 1).detach().clone() # ) # mesh_query_y = ( # sample.mesh_feat[:, 1].view(-1, 1).detach().clone() # ) # mesh_query_x.requires_grad = True # mesh_query_y.requires_grad = True # mesh_query = torch.cat([mesh_query_x, mesh_query_y], dim=-1) # num_nodes = mesh_query.shape[-2] // bs # # Generate random mesh queries for unsupervised learning # sampled_queries = generate_samples( # bs=bs, # num_samples_per_mesh=num_nodes, # num_meshes=5, # data=sample, # device=self.device, # ) # sampled_queries_edge_index = construct_graph( # sampled_queries[:, :, :2], num_neighbors=6 # ) # mesh_sampled_queries_x = ( # sampled_queries[:, :, 0].view(-1, 1).detach() # ) # mesh_sampled_queries_y = ( # sampled_queries[:, :, 1].view(-1, 1).detach() # ) # mesh_sampled_queries_x.requires_grad = True # mesh_sampled_queries_y.requires_grad = True # mesh_sampled_queries = torch.cat( # [mesh_sampled_queries_x, mesh_sampled_queries_y], dim=-1 # ).view(-1, 2) # coord_ori_x = sample.mesh_feat[:, 0].view(-1, 1) # coord_ori_y = sample.mesh_feat[:, 1].view(-1, 1) # coord_ori_x.requires_grad = True # coord_ori_y.requires_grad = True # coord_ori = torch.cat([coord_ori_x, coord_ori_y], dim=-1) # num_nodes = coord_ori.shape[-2] // bs # input_q = sample.mesh_feat[:, :4] # input_kv = generate_samples( # bs=bs, # num_samples_per_mesh=num_nodes, # data=sample, # device=self.device, # ) # # print(f"batch size: {bs}, num_nodes: {num_nodes}, input q", input_q.shape, "input_kv ", input_kv.shape) # (output_coord_all, output, 
out_monitor), (phix, phiy) = ( # self.model( # sample, # input_q, # input_q, # mesh_query, # sampled_queries=None, # sampled_queries_edge_index=None, # ) # ) # # (output_coord_all, output, out_monitor), (phix, phiy) = model(data, input_q, input_kv, mesh_query, sampled_queries, sampled_queries_edge_index) # out = output_coord_all[: num_nodes * bs] elif self.model_used == "M2N": out = self.model(sample) elif self.model_used == "MRN": out = self.model(sample) else: raise Exception(f"model {self.model_used} not implemented.") end = time.perf_counter() dur_ms = (end - start) * 1000 # TODO: do not uncomment this, there are accumulation errors # # calculate solution on fine mesh # function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1) # self.solve_u_fine(self.t) # self.u_fine = fd.Function(function_space_fine).project( # self.u_cur_fine # ) # noqa # calculate solution on original mesh self.mesh.coordinates.dat.data[:] = self.init_coord self.project_u_() self.solve_u(self.t) function_space = fd.FunctionSpace(self.mesh_coarse, "CG", 1) self.uh = fd.Function(function_space).project(self.u_cur) # calculate solution on adapted mesh self.adapt_coord = y self.mesh.coordinates.dat.data[:] = self.adapt_coord self.mesh_new.coordinates.dat.data[:] = self.adapt_coord self.project_u_() self.solve_u(self.t) function_space_new = fd.FunctionSpace(self.mesh_new, "CG", 1) # noqa self.uh_new = fd.Function(function_space_new).project(self.u_cur) # noqa # calculate solution on model output mesh self.adapt_coord = out.detach().cpu().numpy() self.mesh.coordinates.dat.data[:] = self.adapt_coord self.mesh_model.coordinates.dat.data[:] = self.adapt_coord self.project_u_() self.solve_u(self.t) function_space_model = fd.FunctionSpace(self.mesh_model, "CG", 1) # noqa self.uh_model = fd.Function(function_space_model).project(self.u_cur) # noqa # check mesh integrity - Only perform evaluation on non-tangling mesh # noqa num_tangle = UM2N.get_sample_tangle(out, sample.x[:, :2], sample.face) # noqa if isinstance(num_tangle, torch.Tensor): num_tangle = num_tangle.item() if num_tangle > 0: # has tangled elems: res["tangled_element"] = num_tangle res["error_model"] = -1 else: # mesh is valid, perform evaluation: 1. 
res["tangled_element"] = num_tangle self.mesh_model.coordinates.dat.data[:] = out.detach().cpu().numpy() # solve on fine mesh function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1) # self.solve_u_fine(self.t) u_fine = fd.Function(function_space_fine).project(self.u_cur_fine) # noqa fig, plot_data_dict = UM2N.plot_compare( self.mesh_fine, self.mesh_coarse, self.mesh_new, self.mesh_model, u_fine, self.uh, self.uh_new, self.uh_model, raw_data.get("hessian_norm")[:, 0], raw_data.get("monitor_val")[:, 0], num_tangle, model_name, ) res["deform_loss"] = 1000 * torch.nn.L1Loss()(out, sample.y).item() plot_data_dict["deform_loss"] = res["deform_loss"] fig.savefig(os.path.join(self.plot_more_path, f"plot_{idx:04d}.png")) # noqa plt.close(fig) # Save plot data with open( os.path.join(self.plot_data_path, f"plot_data_{idx:04d}.pkl"), "wb" ) as p: pickle.dump(plot_data_dict, p) # ======================== Legacy plotting ======================== # # error measuring # ( # u_fine, # u_og_fine, # u_ma_fine, # u_og_coarse, # u_ma_coarse, # error_og, # error_adapt, # ) = self.get_error() error_model = plot_data_dict["error_norm_model"] error_og = plot_data_dict["error_norm_original"] error_ma = plot_data_dict["error_norm_ma"] print(f"error_og: {error_og}, \terror_ma: {error_ma}") res["error_og"] = error_og res["error_ma"] = error_ma res["error_model"] = error_model print("inspect out type: ", type(out.detach().cpu().numpy())) # get time_MA by reading log file res["time_consumption_MA"] = get_log_og( os.path.join(self.ds_root, "log"), idx )["time"] print(res) # metric calculation # res["deform_loss"] = 1000 * torch.nn.L1Loss()(out, sample.y).item() res["time_consumption_model"] = dur_ms res["acceration_ratio"] = ( res["time_consumption_MA"] / res["time_consumption_model"] ) # noqa res["error_reduction_MA"] = (res["error_og"] - res["error_ma"]) / res[ "error_og" ] # noqa res["error_reduction_model"] = ( res["error_og"] - res["error_model"] ) / res["error_og"] # noqa # save file df = pd.DataFrame(res, index=[0]) df.to_csv(os.path.join(self.log_path, f"log_{idx:04d}.csv")) # plot compare mesh plot_fig = UM2N.plot_mesh_compare_benchmark( out.detach().cpu().numpy(), sample.y.detach().cpu().numpy(), sample.face.detach().cpu().numpy(), deform_loss=res["deform_loss"], pde_loss_model=res["error_model"], pde_loss_reduction_model=res["error_reduction_model"], pde_loss_MA=res["error_ma"], pde_loss_reduction_MA=res["error_reduction_MA"], tangle=res["tangled_element"], ) plot_fig.savefig(os.path.join(self.plot_path, f"plot_{idx:04d}.png")) # noqa # plotting (visulisation during sovling) plot = False if plot is True: self.plot_res() plt.show() idx += 1 plt.close() eval_cnt += 1 # time stepping and prep for next solving iter self.t += self.dt step += 1 self.u_fine.assign(self.u_cur_fine) self.u_fine_buffer.assign(self.u_cur_fine) if eval_cnt >= self.num_samples_to_eval: break return

    def vis_evaluate(self, sample):
        """
        It would be great to have some visuals here to assist our judgment.
        """
        print("In evaluation VISUALISATION")
        self.mesh.coordinates.dat.data[:] = sample.y
        fd.triplot(self.mesh)
        self.mesh.coordinates.dat.data[:] = self.init_coord
        plt.show()
        return

    def get_error(self):
        # solve on fine mesh
        function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        self.solve_u_fine(self.t)
        u_fine = fd.Function(function_space_fine).project(self.u_cur_fine)  # noqa

        # solve on coarse mesh
        self.mesh.coordinates.dat.data[:] = self.init_coord
        self.project_u_()
        self.solve_u(self.t)
        function_space_coarse = fd.FunctionSpace(self.mesh, "CG", 1)
        # u_og_2_coarse = fd.project(self.u_cur, function_space_coarse)
        u_og_2_coarse = fd.Function(function_space_coarse).project(self.u_cur)  # noqa
        u_og_2_fine = fd.project(self.u_cur, function_space_fine)

        # solve on coarse adapted mesh
        self.mesh.coordinates.dat.data[:] = self.adapt_coord
        self.project_u_()
        self.solve_u(self.t)
        function_space_coarse = fd.FunctionSpace(self.mesh, "CG", 1)
        # u_adapt_2_coarse = fd.project(self.u_cur, function_space_coarse)
        u_adapt_2_coarse = fd.Function(function_space_coarse).project(self.u_cur)  # noqa
        u_adapt_2_fine = fd.project(self.u_cur, function_space_fine)

        # error calculation
        error_og = fd.errornorm(u_fine, u_og_2_fine, norm_type="L2")
        error_adapt = fd.errornorm(u_fine, u_adapt_2_fine, norm_type="L2")

        # put mesh to init state
        self.mesh.coordinates.dat.data[:] = self.init_coord

        # u_fine_raw_data = u_fine.dat.data[:]
        return (
            u_fine,
            u_og_2_fine,
            u_adapt_2_fine,
            u_og_2_coarse,
            u_adapt_2_coarse,
            error_og,
            error_adapt,
        )
        # return u_fine_raw_data, u_og_raw_data, u_adapt_raw_data, error_og, error_adapt
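
    # Editorial note (not part of the original source): get_error projects both the
    # original-mesh and adapted-mesh solutions onto the fine function space and measures
    # each against the fine reference solution, i.e. it reports
    #
    #     error = || u_ref - u_h ||_{L2(Omega)}
    #
    # via fd.errornorm with norm_type="L2".
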

    def plot_res(self):
        fig = plt.figure(figsize=(15, 10))

        ax1 = fig.add_subplot(2, 3, 1, projection="3d")
        ax1.set_title("Solution on fine mesh")
        fd.trisurf(self.u_cur_fine, axes=ax1)

        ax2 = fig.add_subplot(2, 3, 2, projection="3d")
        ax2.set_title("Solution on original mesh")
        fd.trisurf(self.uh, axes=ax2)

        ax3 = fig.add_subplot(2, 3, 3, projection="3d")
        ax3.set_title("Solution on adapted mesh")
        fd.trisurf(self.uh_new, axes=ax3)

        ax4 = fig.add_subplot(2, 3, 4)
        ax4.set_title("Fine mesh")
        fd.triplot(self.mesh_fine, axes=ax4)

        ax5 = fig.add_subplot(2, 3, 5)
        ax5.set_title("Original mesh")
        fd.tripcolor(self.uh, axes=ax5, cmap="coolwarm")
        fd.triplot(self.mesh, axes=ax5)

        ax6 = fig.add_subplot(2, 3, 6)
        ax6.set_title("Adapted mesh")
        fd.tripcolor(self.uh_new, axes=ax6, cmap="coolwarm")
        fd.triplot(self.mesh_new, axes=ax6)

        return fig