# Author: Chunyang Wang
# GitHub username: chunyang-w
import os
import pickle
import time
from pprint import pprint  # noqa
import firedrake as fd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader
import UM2N
from UM2N.model.train_util import model_forward
[docs]
def get_log_og(log_path, idx):
    """
    Read the dataset-generation log for sample ``idx`` and return its metrics.

    The file ``log_path/log{idx:04d}.csv`` is read with pandas and the values
    of the ``error_og``, ``error_adapt`` and ``time`` columns in the first
    row (label 0) are returned as a dict with those same keys.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
        KeyError: if one of the three columns is missing.
    """
    csv_file = os.path.join(log_path, f"log{idx:04d}.csv")
    frame = pd.read_csv(csv_file)
    wanted_columns = ("error_og", "error_adapt", "time")
    return {column: frame[column][0] for column in wanted_columns}
[docs]
class SwirlEvaluator:
    """
    Evaluate error for advection swirl problem:
        1. Solver implementation for the swirl problem
        2. Error & Time evaluation

    The same scalar advection problem is assembled twice: on ``mesh``, a
    working coarse mesh whose coordinates are overwritten to solve on
    different node layouts, and on ``mesh_fine``, which provides the
    reference solution. Each mesh gets three linear solvers, one per stage of
    the explicit Runge-Kutta update used by ``solve_u`` / ``solve_u_fine``.
    """

    def __init__(
        self,
        mesh,
        mesh_coarse,
        mesh_fine,
        mesh_new,
        mesh_model,
        dataset,
        model,
        eval_dir,
        ds_root,
        **kwargs,
    ):  # noqa
        """
        Init the problem:
            1. define problem on fine mesh and coarse mesh
            2. init function space on fine & coarse mesh

        Args:
            mesh: working coarse mesh; the PDE is solved on it and its
                coordinates are overwritten during evaluation.
            mesh_coarse: coarse mesh used to hold the original-mesh solution.
            mesh_fine: fine mesh used for the reference solution.
            mesh_new: mesh that receives the MA-adapted coordinates.
            mesh_model: mesh that receives the model-generated coordinates.
            dataset: evaluation dataset; ``eval_problem`` indexes it as
                ``dataset[idx]`` and ``dataset.file_names[idx]``.
            model: the neural network producing adapted mesh coordinates.
            eval_dir: root directory for evaluation logs and plots.
            ds_root: dataset root; MA timings are read from ``ds_root/log``.
            **kwargs: optional settings, each popped here with its default:
                ``device`` ("cuda"), ``model_used`` ("MRTransformer"),
                ``save_interval`` (5), ``num_samples_to_eval`` (100),
                ``T`` (1), ``n_step`` (500), ``sigma`` (0.05 / 3),
                ``alpha`` (1.5), ``r_0`` (0.2), ``x_0`` (0.25), ``y_0`` (0.25).
                Any other keys are silently ignored.
        """
        self.device = kwargs.pop("device", "cuda")
        self.model_used = kwargs.pop("model_used", "MRTransformer")
        # mesh vars
        self.mesh_coarse = mesh_coarse  # coarse mesh
        self.mesh = mesh  # mesh buffer for solving equations
        self.mesh_fine = mesh_fine  # fine mesh
        self.mesh_new = mesh_new  # adapted mesh by MA
        self.mesh_model = mesh_model  # adapted mesh by model
        # evaluation vars
        self.dataset = dataset  # dataset containing all data
        self.model = model  # the NN model
        self.eval_dir = eval_dir  # evaluation root dir
        self.ds_root = ds_root
        # output sub-directories below eval_dir (created by the make_*_dir methods)
        self.log_path = os.path.join(eval_dir, "log")
        self.plot_path = os.path.join(eval_dir, "plot")
        self.plot_more_path = os.path.join(eval_dir, "plot_more")
        self.plot_data_path = os.path.join(eval_dir, "plot_data")
        self.save_interval = kwargs.pop("save_interval", 5)
        self.num_samples_to_eval = kwargs.pop("num_samples_to_eval", 100)
        # Init coords setup: snapshot the node coordinates as (num_nodes, 2)
        # arrays so the working mesh can be restored after being moved.
        self.init_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.init_coord_fine = (
            self.mesh_fine.coordinates.vector().array().reshape(-1, 2)
        )  # noqa
        self.best_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.adapt_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)  # noqa
        # error measuring vars
        self.error_adapt_list = []
        self.error_og_list = []
        self.best_error_iter = 0
        # X and Y coordinates
        self.x, self.y = fd.SpatialCoordinate(mesh)
        self.x_fine, self.y_fine = fd.SpatialCoordinate(self.mesh_fine)
        # function space on coarse mesh
        self.scalar_space = fd.FunctionSpace(self.mesh, "CG", 1)
        self.vector_space = fd.VectorFunctionSpace(self.mesh, "CG", 1)
        self.tensor_space = fd.TensorFunctionSpace(self.mesh, "CG", 1)
        # function space on fine mesh
        self.scalar_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        self.vector_space_fine = fd.VectorFunctionSpace(self.mesh_fine, "CG", 1)  # noqa
        # Test/Trial function on coarse mesh
        self.du_trial = fd.TrialFunction(self.scalar_space)
        self.phi = fd.TestFunction(self.scalar_space)
        # Test/Trial function on fine mesh
        self.du_trial_fine = fd.TrialFunction(self.scalar_space_fine)
        self.phi_fine = fd.TestFunction(self.scalar_space_fine)
        # outward facet normal on coarse / fine mesh
        self.n = fd.FacetNormal(self.mesh)
        self.n_fine = fd.FacetNormal(self.mesh_fine)
        # simulation params
        self.T = kwargs.pop("T", 1)
        self.t = 0.0
        self.n_step = kwargs.pop("n_step", 500)
        self.threshold = (
            self.T / 2
        )  # time at which the swirl direction reverses (per original note)  # noqa
        # uniform time step: the interval [0, T] is split into n_step steps
        self.dt = self.T / self.n_step
        # self.dt = kwargs.pop("dt", 1e-3)
        self.dtc = fd.Constant(self.dt)
        # initial condition params (forwarded to UM2N.get_u_0 / UM2N.get_c)
        self.sigma = kwargs.pop("sigma", (0.05 / 3))
        self.alpha = kwargs.pop("alpha", 1.5)
        self.r_0 = kwargs.pop("r_0", 0.2)
        self.x_0 = kwargs.pop("x_0", 0.25)
        self.y_0 = kwargs.pop("y_0", 0.25)
        # initial condition of u on coarse / fine mesh
        u_init_exp = UM2N.get_u_0(
            self.x, self.y, self.r_0, self.x_0, self.y_0, self.sigma
        )  # noqa
        u_init_exp_fine = UM2N.get_u_0(
            self.x_fine, self.y_fine, self.r_0, self.x_0, self.y_0, self.sigma
        )  # noqa
        self.u_init = fd.Function(self.scalar_space).interpolate(u_init_exp)
        self.u_init_fine = fd.Function(self.scalar_space_fine).interpolate(
            u_init_exp_fine
        )  # noqa
        # PDE vars on coarse & fine mesh
        #       solution field u (u1/u2 hold the intermediate RK stages)
        self.u = fd.Function(self.scalar_space).assign(self.u_init)
        self.u1 = fd.Function(self.scalar_space)
        self.u2 = fd.Function(self.scalar_space)
        self.u_fine = fd.Function(self.scalar_space_fine).assign(self.u_init_fine)  # noqa
        self.u1_fine = fd.Function(self.scalar_space_fine)
        self.u2_fine = fd.Function(self.scalar_space_fine)
        # inflow boundary value, used where dot(c, n) < 0
        self.u_in = fd.Constant(0.0)
        self.u_in_fine = fd.Constant(0.0)
        #       temp vars for saving u on coarse & fine mesh
        self.u_cur = fd.Function(
            self.scalar_space
        )  # solution from current time step  # noqa
        self.u_cur_fine = fd.Function(self.scalar_space_fine)
        self.u_hess = fd.Function(
            self.scalar_space
        )  # buffer for hessian solver usage  # noqa
        #       buffers (u_fine_buffer is projected onto the coarse mesh by project_u_)
        self.u_fine_buffer = fd.Function(self.scalar_space_fine).assign(
            self.u_init_fine
        )  # noqa
        self.coarse_adapt = fd.Function(self.scalar_space)
        self.coarse_2_fine = fd.Function(self.scalar_space_fine)
        self.coarse_2_fine_original = fd.Function(self.scalar_space_fine)
        #       velocity field - the swirl: c
        self.c = fd.Function(self.vector_space)
        self.c_fine = fd.Function(self.vector_space_fine)
        # cn = max(dot(c, n), 0): the outgoing normal velocity, used for the
        # upwind flux on interior facets
        self.cn = 0.5 * (fd.dot(self.c, self.n) + abs(fd.dot(self.c, self.n)))
        self.cn_fine = 0.5 * (
            fd.dot(self.c_fine, self.n_fine) + abs(fd.dot(self.c_fine, self.n_fine))
        )  # noqa
        # PDE problem LHS (bilinear form) on coarse & fine mesh: the mass
        # matrix, so each solve yields du = M^-1 * L
        self.a = self.phi * self.du_trial * fd.dx(domain=self.mesh)
        self.a_fine = self.phi_fine * self.du_trial_fine * fd.dx(domain=self.mesh_fine)  # noqa
        # PDE problem RHS (linear forms) on coarse & fine mesh:
        #   dt * ( volume term
        #          - inflow boundary flux (dot(c, n) < 0, uses u_in)
        #          - outflow boundary flux (dot(c, n) > 0, uses u)
        #          - upwinded interior-facet flux via cn )
        #       on coarse mesh
        self.L1 = self.dtc * (
            self.u * fd.div(self.phi * self.c) * fd.dx(domain=self.mesh)  # noqa
            - fd.conditional(
                fd.dot(self.c, self.n) < 0,
                self.phi * fd.dot(self.c, self.n) * self.u_in,
                0.0,
            )
            * fd.ds(domain=self.mesh)  # noqa
            - fd.conditional(
                fd.dot(self.c, self.n) > 0,
                self.phi * fd.dot(self.c, self.n) * self.u,
                0.0,
            )
            * fd.ds(domain=self.mesh)  # noqa
            - (self.phi("+") - self.phi("-"))
            * (self.cn("+") * self.u("+") - self.cn("-") * self.u("-"))
            * fd.dS(domain=self.mesh)
        )  # noqa
        # stage 2 / stage 3 forms: identical to L1 but evaluated at u1 / u2
        self.L2 = fd.replace(self.L1, {self.u: self.u1})
        self.L3 = fd.replace(self.L1, {self.u: self.u2})
        #       on fine mesh
        self.L1_fine = self.dtc * (
            self.u_fine
            * fd.div(self.phi_fine * self.c_fine)
            * fd.dx(domain=self.mesh_fine)  # noqa
            - fd.conditional(
                fd.dot(self.c_fine, self.n_fine) < 0,
                self.phi_fine * fd.dot(self.c_fine, self.n_fine) * self.u_in_fine,
                0.0,
            )
            * fd.ds(domain=self.mesh_fine)  # noqa
            - fd.conditional(
                fd.dot(self.c_fine, self.n_fine) > 0,
                self.phi_fine * fd.dot(self.c_fine, self.n_fine) * self.u_fine,
                0.0,
            )
            * fd.ds(domain=self.mesh_fine)  # noqa
            - (self.phi_fine("+") - self.phi_fine("-"))
            * (
                self.cn_fine("+") * self.u_fine("+")
                - self.cn_fine("-") * self.u_fine("-")
            )
            * fd.dS(domain=self.mesh_fine)
        )  # noqa
        self.L2_fine = fd.replace(self.L1_fine, {self.u_fine: self.u1_fine})
        self.L3_fine = fd.replace(self.L1_fine, {self.u_fine: self.u2_fine})
        # vars for storing final solutions (stage increments written by the solvers)
        self.du = fd.Function(self.scalar_space)
        self.du_fine = fd.Function(self.scalar_space_fine)
        # PDE solver (on coarse & fine mesh) setup: a single application of a
        # block-Jacobi / ILU preconditioner, no Krylov iterations
        params = {
            "ksp_type": "preonly",
            "pc_type": "bjacobi",
            "sub_pc_type": "ilu",
        }  # noqa
        #       On coarse mesh: one problem/solver per RK stage, all writing into self.du
        self.prob1 = fd.LinearVariationalProblem(self.a, self.L1, self.du)
        self.solv1 = fd.LinearVariationalSolver(self.prob1, solver_parameters=params)  # noqa
        self.prob2 = fd.LinearVariationalProblem(self.a, self.L2, self.du)
        self.solv2 = fd.LinearVariationalSolver(self.prob2, solver_parameters=params)  # noqa
        self.prob3 = fd.LinearVariationalProblem(self.a, self.L3, self.du)
        self.solv3 = fd.LinearVariationalSolver(self.prob3, solver_parameters=params)  # noqa
        #       On fine mesh: same layout, all writing into self.du_fine
        self.prob1_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L1_fine, self.du_fine
        )  # noqa
        self.solv1_fine = fd.LinearVariationalSolver(
            self.prob1_fine, solver_parameters=params
        )  # noqa
        self.prob2_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L2_fine, self.du_fine
        )  # noqa
        self.solv2_fine = fd.LinearVariationalSolver(
            self.prob2_fine, solver_parameters=params
        )  # noqa
        self.prob3_fine = fd.LinearVariationalProblem(
            self.a_fine, self.L3_fine, self.du_fine
        )  # noqa
        self.solv3_fine = fd.LinearVariationalSolver(
            self.prob3_fine, solver_parameters=params
        )  # noqa
[docs]
    def solve_u(self, t):
        """
        Advance the coarse-mesh solution by one SSPRK step.

        The swirl velocity at time ``t`` is interpolated and projected into
        ``self.c``. Three stage solves follow; each fills ``self.du`` with a
        mass-matrix-inverted increment. The combined result is stored in
        ``self.u_cur``. ``self.u`` itself is left unchanged.
        """
        velocity_expr = UM2N.get_c(self.x, self.y, t, alpha=self.alpha)
        velocity = fd.Function(self.vector_space).interpolate(velocity_expr)
        self.c.project(velocity)

        u_n, incr = self.u, self.du
        # stage 1: incr <- increment evaluated at u_n
        self.solv1.solve()
        self.u1.assign(u_n + incr)
        # stage 2: incr <- increment evaluated at u1
        self.solv2.solve()
        self.u2.assign(0.75 * u_n + 0.25 * (self.u1 + incr))
        # stage 3: incr <- increment evaluated at u2
        self.solv3.solve()
        self.u_cur.assign((1.0 / 3.0) * u_n + (2.0 / 3.0) * (self.u2 + incr))
[docs]
    def solve_u_fine(self, t):
        """
        Advance the fine-mesh solution by one SSPRK step.

        This mirrors ``solve_u`` on ``mesh_fine``. The velocity at time ``t``
        is refreshed in ``self.c_fine``. Three stage solves then fill
        ``self.du_fine``, and the combined result is stored in
        ``self.u_cur_fine``. ``self.u_fine`` is not modified.
        """
        velocity_expr = UM2N.get_c(self.x_fine, self.y_fine, t, alpha=self.alpha)
        velocity = fd.Function(self.vector_space_fine).interpolate(velocity_expr)
        self.c_fine.project(velocity)

        u_n, incr = self.u_fine, self.du_fine
        # stage 1: incr <- increment evaluated at u_n
        self.solv1_fine.solve()
        self.u1_fine.assign(u_n + incr)
        # stage 2: incr <- increment evaluated at u1_fine
        self.solv2_fine.solve()
        self.u2_fine.assign(0.75 * u_n + 0.25 * (self.u1_fine + incr))
        # stage 3: incr <- increment evaluated at u2_fine
        self.solv3_fine.solve()
        self.u_cur_fine.assign(
            (1.0 / 3.0) * u_n + (2.0 / 3.0) * (self.u2_fine + incr)
        )
[docs]
    def project_u_(self):
        """
        Reset the coarse solution ``self.u`` from the fine-mesh buffer.

        ``self.u_fine_buffer`` is projected onto the function space of
        ``self.mesh`` in its current coordinate layout. Each coarse solve in
        the evaluation loop therefore starts from the current fine-mesh state.
        """
        source = self.u_fine_buffer
        self.u.project(source)
[docs]
    def make_log_dir(self):
        """Ensure the CSV log directory ``self.log_path`` exists."""
        target_dir = self.log_path
        UM2N.mkdir_if_not_exist(target_dir)
[docs]
    def make_plot_dir(self):
        """Ensure the mesh-comparison plot directory ``self.plot_path`` exists."""
        target_dir = self.plot_path
        UM2N.mkdir_if_not_exist(target_dir)
[docs]
    def make_plot_more_dir(self):
        """Ensure the solution-comparison plot directory ``self.plot_more_path`` exists."""
        target_dir = self.plot_more_path
        UM2N.mkdir_if_not_exist(target_dir)
[docs]
    def make_plot_data_dir(self):
        """Ensure the pickled plot-data directory ``self.plot_data_path`` exists."""
        target_dir = self.plot_data_path
        UM2N.mkdir_if_not_exist(target_dir)
[docs]
    def eval_problem(self, model_name="model", *, plot=False):
        """
        Step the swirl problem in time and evaluate the model at dataset samples.

        On every time step the reference solution is advanced on the fine
        mesh. Evaluation is triggered whenever the current time matches
        ``swirl_params["t"]`` of the next dataset sample (within 1e-5). On
        each trigger:

        - the model is run on that sample;
        - the PDE is solved on the original mesh, the MA-adapted mesh
          (``raw_data["y"]``) and the model-generated mesh;
        - errors, timings and plots are written under ``self.log_path``,
          ``self.plot_path``, ``self.plot_more_path`` and
          ``self.plot_data_path``.

        The loop stops after ``self.n_step`` steps, after
        ``self.num_samples_to_eval`` evaluated samples, or when the dataset
        runs out of samples, whichever comes first.

        Args:
            model_name: label forwarded to ``UM2N.plot_compare``.
            plot: if True, also draw ``plot_res`` and call ``plt.show()``
                after each evaluated sample. This blocks on interactive
                backends.

        Returns:
            None. All results are written to disk.
        """
        print("In eval problem")
        self.t = 0.0
        step = 0
        idx = 0
        eval_cnt = 0
        res = {
            "deform_loss": None,  # nodal position loss
            "tangled_element": None,  # tangled elements on a mesh  # noqa
            "error_og": None,  # PDE error on original uniform mesh  # noqa
            "error_model": None,  # PDE error on model generated mesh   # noqa
            "error_ma": None,  # PDE error on MA generated mesh      # noqa
            "error_reduction_MA": None,  # PDE error reduced by using MA mesh  # noqa
            "error_reduction_model": None,  # PDE error reduced by using model mesh  # noqa
            "time_consumption_model": None,  # time consumed generating mesh inferenced by the model  # noqa
            "time_consumption_MA": None,  # time consumed generating mesh by Monge-Ampere method  # noqa
            "acceration_ratio": None,  # time_consumption_ma / time_consumption_model  # noqa
        }
        raw_data = None
        loaded_idx = None  # index of the sample currently held in raw_data
        for _ in range(self.n_step):
            print("evalutation, time: ", self.t)
            # Only (re)load the raw sample when the target index changes. On
            # most steps only its time stamp is needed.
            if loaded_idx != idx:
                try:
                    raw_data_path = self.dataset.file_names[idx]
                except IndexError:
                    # Dataset exhausted before num_samples_to_eval was reached.
                    break
                # NOTE: allow_pickle=True unpickles object arrays; only use
                # with trusted, locally generated dataset files.
                raw_data = np.load(raw_data_path, allow_pickle=True).item()
                loaded_idx = idx
            data_t = raw_data.get("swirl_params")["t"]
            # error tracking lists init
            self.error_adapt_list = []
            self.error_og_list = []
            # advance the reference solution on the fine mesh
            self.solve_u_fine(self.t)
            if abs(self.t - data_t) < 1e-5:
                # Evaluation time step hit
                self._evaluate_sample(
                    idx, step, data_t, raw_data, res, model_name, plot
                )
                idx += 1
                eval_cnt += 1
            # time stepping and prep for next solving iter
            self.t += self.dt
            step += 1
            self.u_fine.assign(self.u_cur_fine)
            self.u_fine_buffer.assign(self.u_cur_fine)
            if eval_cnt >= self.num_samples_to_eval:
                break

    def _run_model(self, sample):
        """Run ``self.model`` on ``sample``; return (output coords, inference time in ms)."""
        self.model.eval()
        self.model = self.model.to(self.device)
        bs = 1
        # CUDA kernels run asynchronously; synchronise around the timed region
        # so the measurement covers the actual computation, not just the launch.
        on_cuda = torch.device(self.device).type == "cuda"
        with torch.no_grad():
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            if self.model_used in ("MRTransformer", "M2T"):
                output_coord, *_ = model_forward(
                    bs,
                    sample,
                    self.model,
                    use_add_random_query=False,
                )
                out = output_coord
            elif self.model_used in ("M2N", "MRN"):
                out = self.model(sample)
            else:
                raise Exception(f"model {self.model_used} not implemented.")
            if on_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()
        return out, (end - start) * 1000

    def _solve_on_coords(self, coords, target_mesh):
        """Move ``self.mesh`` to ``coords``, solve one step, and project ``u_cur`` onto ``target_mesh``."""
        self.mesh.coordinates.dat.data[:] = coords
        self.project_u_()
        self.solve_u(self.t)
        function_space = fd.FunctionSpace(target_mesh, "CG", 1)
        return fd.Function(function_space).project(self.u_cur)

    def _evaluate_sample(self, idx, step, data_t, raw_data, res, model_name, plot):
        """
        Evaluate dataset sample ``idx`` at the current time and write its outputs.

        Fills ``res`` in place and sets ``self.uh``, ``self.uh_new``,
        ``self.uh_model`` and ``self.adapt_coord``. Writes one solution plot,
        one pickled plot-data file, one CSV log and one mesh-comparison plot.
        """
        print(
            f"---- evaluating samples: step: {step}, t: {self.t:.5f}, data_t: {data_t:.5f}"
        )  # noqa
        sample = next(
            iter(DataLoader([self.dataset[idx]], batch_size=1, shuffle=False))
        )
        sample = sample.to(self.device)
        out, dur_ms = self._run_model(sample)

        # solution on the original (uniform) mesh
        self.uh = self._solve_on_coords(self.init_coord, self.mesh_coarse)
        # solution on the MA-adapted mesh
        self.adapt_coord = raw_data.get("y")
        self.mesh_new.coordinates.dat.data[:] = self.adapt_coord
        self.uh_new = self._solve_on_coords(self.adapt_coord, self.mesh_new)
        # solution on the model-generated mesh
        self.adapt_coord = out.detach().cpu().numpy()
        self.mesh_model.coordinates.dat.data[:] = self.adapt_coord
        self.uh_model = self._solve_on_coords(self.adapt_coord, self.mesh_model)

        # check mesh integrity of the model output
        num_tangle = UM2N.get_sample_tangle(out, sample.x[:, :2], sample.face)  # noqa
        if isinstance(num_tangle, torch.Tensor):
            num_tangle = num_tangle.item()
        res["tangled_element"] = num_tangle
        # NOTE(review): earlier code set res["error_model"] = -1 for tangled
        # meshes, but that value was always overwritten by the PDE error
        # below. The logged error is therefore the raw PDE error regardless
        # of tangling. Decide whether tangled meshes should be flagged instead.

        # reference solution on the fine mesh (already advanced this step)
        function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        u_fine = fd.Function(function_space_fine).project(self.u_cur_fine)  # noqa
        fig, plot_data_dict = UM2N.plot_compare(
            self.mesh_fine,
            self.mesh_coarse,
            self.mesh_new,
            self.mesh_model,
            u_fine,
            self.uh,
            self.uh_new,
            self.uh_model,
            raw_data.get("hessian_norm")[:, 0],
            raw_data.get("monitor_val")[:, 0],
            num_tangle,
            model_name,
        )
        res["deform_loss"] = 1000 * torch.nn.L1Loss()(out, sample.y).item()
        plot_data_dict["deform_loss"] = res["deform_loss"]
        fig.savefig(os.path.join(self.plot_more_path, f"plot_{idx:04d}.png"))  # noqa
        plt.close(fig)
        # Save plot data
        with open(
            os.path.join(self.plot_data_path, f"plot_data_{idx:04d}.pkl"), "wb"
        ) as p:
            pickle.dump(plot_data_dict, p)

        error_model = plot_data_dict["error_norm_model"]
        error_og = plot_data_dict["error_norm_original"]
        error_ma = plot_data_dict["error_norm_ma"]
        print(f"error_og: {error_og}, \terror_ma: {error_ma}")
        res["error_og"] = error_og
        res["error_ma"] = error_ma
        res["error_model"] = error_model
        # MA mesh-generation time is read from the dataset's own log files
        res["time_consumption_MA"] = get_log_og(
            os.path.join(self.ds_root, "log"), idx
        )["time"]
        res["time_consumption_model"] = dur_ms
        res["acceration_ratio"] = (
            res["time_consumption_MA"] / res["time_consumption_model"]
        )  # noqa
        res["error_reduction_MA"] = (res["error_og"] - res["error_ma"]) / res[
            "error_og"
        ]  # noqa
        res["error_reduction_model"] = (
            res["error_og"] - res["error_model"]
        ) / res["error_og"]  # noqa
        # print only once every metric has been filled in
        print(res)
        # save file
        df = pd.DataFrame(res, index=[0])
        df.to_csv(os.path.join(self.log_path, f"log_{idx:04d}.csv"))
        # plot compare mesh
        plot_fig = UM2N.plot_mesh_compare_benchmark(
            out.detach().cpu().numpy(),
            sample.y.detach().cpu().numpy(),
            sample.face.detach().cpu().numpy(),
            deform_loss=res["deform_loss"],
            pde_loss_model=res["error_model"],
            pde_loss_reduction_model=res["error_reduction_model"],
            pde_loss_MA=res["error_ma"],
            pde_loss_reduction_MA=res["error_reduction_MA"],
            tangle=res["tangled_element"],
        )
        plot_fig.savefig(os.path.join(self.plot_path, f"plot_{idx:04d}.png"))  # noqa
        # close explicitly: it is not guaranteed to be the "current" figure
        plt.close(plot_fig)
        # optional visualisation during solving
        if plot:
            res_fig = self.plot_res()
            plt.show()
            plt.close(res_fig)
[docs]
    def vis_evaluate(self, sample):
        """
        Show the working mesh deformed to the target coordinates of ``sample``.

        ``self.mesh`` is temporarily moved to ``sample.y`` for drawing, then
        restored to ``self.init_coord`` before ``plt.show()`` is called.
        """
        print("In evaluation VISUALISATION")
        target_coords = sample.y
        self.mesh.coordinates.dat.data[:] = target_coords
        fd.triplot(self.mesh)
        # restore the reference layout before handing control to the GUI loop
        self.mesh.coordinates.dat.data[:] = self.init_coord
        plt.show()
[docs]
    def get_error(self):
        """
        Legacy helper: compare coarse solutions against the fine-mesh solution.

        The fine solution is advanced from the current state. The coarse
        problem is then solved twice: on the original layout
        (``self.init_coord``) and on ``self.adapt_coord``. Each coarse result
        is projected onto the fine mesh, and its L2 error against the fine
        solution is computed.

        ``self.mesh`` is reset to ``self.init_coord`` before returning. The
        adapted-mesh coarse field is therefore stored on ``self.mesh_new``,
        which is moved to ``self.adapt_coord``, so that it keeps its geometry.

        Returns:
            tuple: ``(u_fine, u_og_2_fine, u_adapt_2_fine, u_og_2_coarse,
            u_adapt_2_coarse, error_og, error_adapt)``.
        """
        # solve on fine mesh
        function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        self.solve_u_fine(self.t)
        u_fine = fd.Function(function_space_fine).project(self.u_cur_fine)  # noqa
        # solve on the original coarse layout
        self.mesh.coordinates.dat.data[:] = self.init_coord
        self.project_u_()
        self.solve_u(self.t)
        function_space_coarse = fd.FunctionSpace(self.mesh, "CG", 1)
        # lives on self.mesh, which is back at init_coord on return: geometry OK
        u_og_2_coarse = fd.Function(function_space_coarse).project(self.u_cur)  # noqa
        u_og_2_fine = fd.project(self.u_cur, function_space_fine)
        # solve on the adapted coarse layout
        self.mesh.coordinates.dat.data[:] = self.adapt_coord
        self.project_u_()
        self.solve_u(self.t)
        # self.mesh is reset below. A field living on it would be
        # reinterpreted on the uniform layout, so store the adapted solution
        # on mesh_new placed at the adapted coordinates, as eval_problem does.
        self.mesh_new.coordinates.dat.data[:] = self.adapt_coord
        function_space_adapt = fd.FunctionSpace(self.mesh_new, "CG", 1)
        u_adapt_2_coarse = fd.Function(function_space_adapt).project(self.u_cur)  # noqa
        u_adapt_2_fine = fd.project(self.u_cur, function_space_fine)
        # error calculation on the fine mesh
        error_og = fd.errornorm(u_fine, u_og_2_fine, norm_type="L2")
        error_adapt = fd.errornorm(u_fine, u_adapt_2_fine, norm_type="L2")
        # put mesh to init state
        self.mesh.coordinates.dat.data[:] = self.init_coord
        return (
            u_fine,
            u_og_2_fine,
            u_adapt_2_fine,
            u_og_2_coarse,
            u_adapt_2_coarse,
            error_og,
            error_adapt,
        )
[docs]
    def plot_res(self):
        """
        Draw a 2x3 overview of the latest evaluation and return the figure.

        Top row: 3D surfaces of the fine-mesh solution (``self.u_cur_fine``),
        the original-mesh solution (``self.uh``) and the adapted-mesh
        solution (``self.uh_new``). Bottom row: the fine mesh, then the
        original and adapted meshes, each drawn over a colour plot of its
        solution.

        Requires ``self.uh`` and ``self.uh_new`` to exist. ``eval_problem``
        assigns them; ``__init__`` does not.

        Returns:
            matplotlib.figure.Figure: the created figure (not shown or closed).
        """
        fig = plt.figure(figsize=(15, 10))
        ax1 = fig.add_subplot(2, 3, 1, projection="3d")
        ax1.set_title("Solution on fine mesh")
        fd.trisurf(self.u_cur_fine, axes=ax1)
        ax2 = fig.add_subplot(2, 3, 2, projection="3d")
        ax2.set_title("Solution on original mesh")
        fd.trisurf(self.uh, axes=ax2)
        ax3 = fig.add_subplot(2, 3, 3, projection="3d")
        ax3.set_title("Solution on adapt mesh")
        fd.trisurf(self.uh_new, axes=ax3)
        ax4 = fig.add_subplot(2, 3, 4)
        ax4.set_title("Fine mesh")
        fd.triplot(self.mesh_fine, axes=ax4)
        ax5 = fig.add_subplot(2, 3, 5)
        ax5.set_title("Orignal mesh")
        fd.tripcolor(self.uh, axes=ax5, cmap="coolwarm")
        # Overlay the mesh self.uh lives on (mesh_coarse). self.mesh is the
        # working mesh and may still hold adapted/model coordinates here.
        fd.triplot(self.mesh_coarse, axes=ax5)
        ax6 = fig.add_subplot(2, 3, 6)
        ax6.set_title("adapted mesh")
        fd.tripcolor(self.uh_new, axes=ax6, cmap="coolwarm")
        fd.triplot(self.mesh_new, axes=ax6)
        return fig