# Author: Chunyang Wang
# GitHub Username: chunyang-w
import os # noqa
import random # noqa
import time # noqa
from pprint import pprint # noqa
import firedrake as fd
import matplotlib.pyplot as plt # noqa
import numpy as np # noqa
import pandas as pd # noqa
import torch # noqa
from torch_geometric.loader import DataLoader
import UM2N
from UM2N.model.train_util import generate_samples
def get_log_og(log_path, idx):
    """
    Load the metrics stored in ``log{idx}.csv`` inside ``log_path``.

    Returns a dict with the "error_og", "error_adapt" and "time" values
    found in the row labelled 0 of that CSV file.
    """
    csv_file = os.path.join(log_path, f"log{idx}.csv")
    table = pd.read_csv(csv_file)
    wanted_columns = ("error_og", "error_adapt", "time")
    return {column: table.loc[0, column] for column in wanted_columns}
def get_first_entry(dataset, target_idx):
    """
    Find the first dataset position whose raw file belongs to ``target_idx``.

    Each raw file listed in ``dataset.file_names`` is loaded in order, and
    its "idx" entry is printed next to ``target_idx`` while scanning.

    Returns:
    - the position of the first matching entry, or None if no entry matches.
    """
    for position in range(len(dataset)):
        entry = np.load(dataset.file_names[position], allow_pickle=True).item()
        entry_idx = entry.get("idx")
        print(entry_idx, " ", target_idx)
        if entry_idx == target_idx:
            return position
    return None
class BurgersEvaluator:
    """
    Evaluate a mesh-movement model on the 2D viscous Burgers equation.

    A reference solution is advanced on ``mesh_fine``. At every time step
    that has a matching dataset entry (same trajectory ``idx`` and time
    ``t``), one coarse step is solved on three meshes: the uniform mesh, the
    Monge-Ampere (MA) mesh stored in the sample (``sample.y``) and the mesh
    predicted by ``model``. Their L2 errors against the fine solution, the
    timings, a CSV log and comparison plots are written below ``eval_dir``.

    Input:
    - mesh: the coarse mesh. Its coordinates are moved in place during the
      evaluation and reset to their initial (uniform) layout afterwards.
    - mesh_fine: the fine mesh used for the reference solution.
    - mesh_new: a second coarse mesh that receives adapted coordinates for
      projection/plotting. It is assigned the same coordinate arrays as
      ``mesh``, so presumably it shares that topology.
    - dataset: dataset exposing ``file_names`` and item access.
    - model: the network that predicts adapted mesh coordinates.
    - eval_dir: evaluation root; output goes to its log/, plot/,
      plot_more/ and plot_data/ sub-directories.
    - ds_root: dataset root; MA timings are read from ``ds_root/log``.
    - idx: index of the trajectory to evaluate.
    """

    def __init__(
        self,
        mesh,
        mesh_fine,
        mesh_new,
        dataset,
        model,
        eval_dir,
        ds_root,
        idx,
        **kwargs,
    ):  # noqa
        """
        Initialise the evaluator and set the initial condition.

        kwargs:
        - device: torch device used for model inference (default "cuda").
        - model_used: "MRTransformer"/"MRT", "M2N" or "MRN"
          (default "MRTransformer").
        - nu: the viscosity of the fluid (default 1e-3).
        - gauss_list: required; sequence of dicts with keys "cx", "cy" and
          "w" describing the Gaussian bumps of the initial x-velocity.
        - dt: the time step (default 1/30).
        - T: the simulated end time (default 2.0).

        Raises:
        - ValueError: if ``gauss_list`` is not given.
        """
        self.device = kwargs.pop("device", "cuda")
        self.model_used = kwargs.pop("model_used", "MRTransformer")
        # Mesh
        self.mesh = mesh
        self.mesh_fine = mesh_fine
        self.mesh_new = mesh_new
        # evaluation vars
        self.dataset = dataset  # dataset containing all data
        self.model = model  # the NN model
        self.eval_dir = eval_dir  # evaluation root dir
        self.ds_root = ds_root
        self.log_path = os.path.join(eval_dir, "log")
        self.plot_path = os.path.join(eval_dir, "plot")
        self.plot_more_path = os.path.join(eval_dir, "plot_more")
        self.plot_data_path = os.path.join(eval_dir, "plot_data")
        self.idx = idx
        # coordinates: snapshots of the initial node positions, used to
        # restore the meshes after they have been moved
        self.init_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.init_coord_fine = (
            self.mesh_fine.coordinates.vector().array().reshape(-1, 2)
        )
        self.best_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.adapt_coord = self.mesh.coordinates.vector().array().reshape(-1, 2)
        self.error_adapt_list = []
        self.error_og_list = []
        self.best_error_iter = 0
        # X and Y coordinates
        self.x, self.y = fd.SpatialCoordinate(mesh)
        self.x_fine, self.y_fine = fd.SpatialCoordinate(self.mesh_fine)
        # Function spaces
        self.P1 = fd.FunctionSpace(mesh, "CG", 1)
        self.P2 = fd.FunctionSpace(mesh, "CG", 2)
        self.P1_vec = fd.VectorFunctionSpace(mesh, "CG", 1)
        self.P2_vec = fd.VectorFunctionSpace(mesh, "CG", 2)
        self.P1_ten = fd.TensorFunctionSpace(mesh, "CG", 1)
        self.P2_ten = fd.TensorFunctionSpace(mesh, "CG", 2)
        self.P1_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        self.P2_vec_fine = fd.VectorFunctionSpace(self.mesh_fine, "CG", 2)
        self.phi_p2_vec_fine = fd.TestFunction(self.P2_vec_fine)
        # Test functions
        self.phi = fd.TestFunction(self.P1)
        self.phi_p2_vec = fd.TestFunction(self.P2_vec)
        self.trial_fine = fd.TrialFunction(self.P1_fine)
        self.phi_fine = fd.TestFunction(self.P1_fine)
        # buffer
        self.u_fine_buffer = fd.Function(self.P2_vec_fine)
        self.coarse_adapt = fd.Function(self.P1_vec)
        self.coarse_2_fine = fd.Function(self.P2_vec_fine)
        self.coarse_2_fine_original = fd.Function(self.P2_vec_fine)
        # simulation params
        self.nu = kwargs.pop("nu", 1e-3)
        self.gauss_list = kwargs.pop("gauss_list", None)
        if self.gauss_list is None:
            # The initial condition is undefined without it; fail clearly
            # instead of with a TypeError from len(None).
            raise ValueError(
                "BurgersEvaluator requires the 'gauss_list' keyword argument."
            )
        self.dt = kwargs.get("dt", 1.0 / 30)
        self.sim_len = kwargs.get("T", 2.0)
        self.T = self.sim_len
        self.dtc = fd.Constant(self.dt)
        # initial x-velocity: sum of bumps exp(-|x - c|^2 / w)
        self.u_init = 0
        self.u_init_fine = 0
        for counter in range(len(self.gauss_list)):
            gauss = self.gauss_list[counter]
            c_x, c_y, w = gauss["cx"], gauss["cy"], gauss["w"]
            self.u_init += fd.exp(
                -((self.x - c_x) ** 2 + (self.y - c_y) ** 2) / w
            )
            self.u_init_fine += fd.exp(
                -((self.x_fine - c_x) ** 2 + (self.y_fine - c_y) ** 2) / w
            )
        # solution vars
        self.u_og = fd.Function(self.P2_vec)  # u_{0}
        self.u = fd.Function(self.P2_vec)  # u_{n+1}
        self.u_ = fd.Function(self.P2_vec)  # u_{n}
        # Implicit (backward Euler) residual: time difference, convection
        # (u . grad) u and the weak form of the viscous term nu * laplace(u).
        self.F = (
            fd.inner((self.u - self.u_) / self.dtc, self.phi_p2_vec)
            + fd.inner(fd.dot(self.u, fd.nabla_grad(self.u)), self.phi_p2_vec)
            + self.nu * fd.inner(fd.grad(self.u), fd.grad(self.phi_p2_vec))
        ) * fd.dx(domain=self.mesh)
        self.u_fine = fd.Function(self.P2_vec_fine)  # u_{n+1} on the fine mesh
        self.u_fine_ = fd.Function(self.P2_vec_fine)  # u_{n} on the fine mesh
        self.F_fine = (
            fd.inner((self.u_fine - self.u_fine_) / self.dtc, self.phi_p2_vec_fine)
            + fd.inner(
                fd.dot(self.u_fine, fd.nabla_grad(self.u_fine)),
                self.phi_p2_vec_fine,
            )
            + self.nu
            * fd.inner(fd.grad(self.u_fine), fd.grad(self.phi_p2_vec_fine))
        ) * fd.dx(domain=self.mesh_fine)
        # initial vals
        self.initial_velocity = fd.as_vector([self.u_init, 0])
        self.initial_velocity_fine = fd.as_vector([self.u_init_fine, 0])
        self.u.project(self.initial_velocity)
        self.u_.assign(self.u)
        self.u_og.assign(self.u)
        ic_fine = fd.project(self.initial_velocity_fine, self.P2_vec_fine)
        self.u_fine.assign(ic_fine)
        self.u_fine_.assign(ic_fine)
        self.u_fine_buffer.assign(ic_fine)
        # solver params (not currently passed to any fd.solve in this class)
        self.sp = {
            "mat_type": "aij",
            "ksp_type": "preonly",
            "pc_type": "lu",
            "pc_factor_mat_solver_type": "mumps",
        }

    def project_u_(self):
        """Project the buffered fine-mesh velocity onto ``u_`` (coarse mesh)."""
        self.u_.project(self.u_fine_buffer)

    def eval_problem(self):
        """
        Run the Burgers simulation and evaluate the model on trajectory ``idx``.

        The fine-mesh solution is advanced from ``t = 0`` towards ``T`` in
        steps of ``dt``. Whenever the next unread dataset entry belongs to
        trajectory ``idx`` and its ``t`` matches the current time, the model
        and MA meshes are evaluated (see ``_eval_step``). The log/plot
        directories are expected to exist (see the ``make_*_dir`` helpers).

        Raises:
        - ValueError: if no dataset entry has ``idx == self.idx``.
        - Exception: if ``model_used`` is not a supported model name.
        """
        print("target index", self.idx)
        idx_start = get_first_entry(self.dataset, self.idx)
        print("idx_start: ", idx_start)
        if idx_start is None:
            raise ValueError(f"No dataset entry found with idx={self.idx}.")
        i = 0  # number of dataset entries consumed so far
        t = 0.0
        self.step = 0
        self.best_error_iter = 0
        res = {
            "deform_loss": None,  # 1. nodal position loss
            "tangled_element": None,  # 2. tangled elements on a mesh # noqa
            "error_og": None,  # 3. PDE error on original uniform mesh # noqa
            "error_model": None,  # 4. PDE error on model generated mesh # noqa
            "error_ma": None,  # 5. PDE error on MA generated mesh # noqa
            "error_reduction_MA": None,  # 6. PDE error reduced by using MA mesh # noqa
            "error_reduction_model": None,  # 7. PDE error reduced by using model mesh # noqa
            "time_consumption_model": None,  # 8. time consumed generating mesh inferenced by the model # noqa
            "time_consumption_MA": None,  # 9. time consumed generating mesh by Monge-Ampere method # noqa
            "acceration_ratio": None,  # 10. time_consumption_ma / time_consumption_model # noqa
        }
        while t < self.T - 0.5 * self.dt:
            cur_step = idx_start + i
            if cur_step >= len(self.dataset.file_names):
                # No dataset entries left to compare against.
                break
            raw_data = np.load(
                self.dataset.file_names[cur_step], allow_pickle=True
            ).item()
            self.error_adapt_list = []
            self.error_og_list = []
            print("step: {}, t: {}".format(self.step, t))
            # advance the fine reference solution by one step
            fd.solve(self.F_fine == 0, self.u_fine)
            print("cur_step: ", cur_step)
            print("compare:", t, raw_data.get("t"), raw_data.get("idx"), self.idx)
            if (abs(t - raw_data.get("t")) < 1e-5) and raw_data.get("idx") == self.idx:
                print("in here", t, raw_data.get("t"), raw_data.get("idx"))
                self._eval_step(cur_step, res)
                i += 1
            # step forward in time
            self.u_fine_.assign(self.u_fine)
            self.u_fine_buffer.assign(self.u_fine)
            t += self.dt
            self.step += 1
        return

    def _eval_step(self, cur_step, res):
        """
        Evaluate the model and MA meshes against the uniform mesh at one step.

        Fills ``res`` in place, writes ``log{idx}_{cur_step}.csv`` to
        ``log_path`` and comparison figures to ``plot_path`` and
        ``plot_more_path``.
        """
        sample = next(
            iter(DataLoader([self.dataset[cur_step]], batch_size=1, shuffle=False))
        )
        sample = sample.to(self.device)
        out, dur_ms = self._infer_mesh(sample)
        model_coord = out.detach().cpu().numpy()
        ma_coord = sample.y.detach().cpu().numpy()

        # check mesh integrity - only evaluate the PDE error on a non-tangled
        # model mesh
        num_tangle = UM2N.get_sample_tangle(out, sample.x[:, :2], sample.face)
        if isinstance(num_tangle, torch.Tensor):
            num_tangle = num_tangle.item()
        res["tangled_element"] = num_tangle
        if num_tangle > 0:
            res["error_model"] = -1  # sentinel: model mesh is tangled
        else:
            self.adapt_coord = model_coord
            _, error_model = self.get_error()
            res["error_model"] = error_model
        # MA mesh generation time, read from the dataset log file cur_step + 1
        res["time_consumption_MA"] = get_log_og(
            os.path.join(self.ds_root, "log"), (cur_step + 1)
        )["time"]
        # metric calculation
        res["deform_loss"] = 1000 * torch.nn.L1Loss()(out, sample.y).item()
        res["time_consumption_model"] = dur_ms
        res["acceration_ratio"] = (
            res["time_consumption_MA"] / res["time_consumption_model"]
        )

        # solution on the MA mesh, kept on mesh_new for plotting
        self.adapt_coord = ma_coord
        self.mesh_new.coordinates.dat.data[:] = ma_coord
        self.mesh.coordinates.dat.data[:] = ma_coord
        self.project_u_()
        fd.solve(self.F == 0, self.u)
        uh_new = fd.Function(fd.VectorFunctionSpace(self.mesh_new, "CG", 1))
        uh_new.project(self.u)
        uh_new_0 = fd.Function(fd.FunctionSpace(self.mesh_new, "CG", 1))
        uh_new_0.project(uh_new[0])

        error_og, error_adapt = self.get_error()
        print("error_og: {}, error_adapt: {}".format(error_og, error_adapt))  # noqa
        res["error_og"] = error_og
        res["error_ma"] = error_adapt
        res["error_reduction_MA"] = (error_og - error_adapt) / error_og
        if num_tangle > 0:
            # error_model only holds the -1 sentinel; a "reduction" derived
            # from it would be meaningless, so record NaN instead.
            res["error_reduction_model"] = float("nan")
        else:
            res["error_reduction_model"] = (
                error_og - res["error_model"]
            ) / error_og

        # save file
        df = pd.DataFrame(res, index=[0])
        df.to_csv(os.path.join(self.log_path, f"log{self.idx}_{cur_step}.csv"))
        # plot compare mesh
        compare_plot = UM2N.plot_mesh_compare_benchmark(
            model_coord,
            ma_coord,
            sample.face.detach().cpu().numpy(),
            res["deform_loss"],
            res["error_model"],
            res["error_reduction_model"],
            res["error_ma"],
            res["error_reduction_MA"],
            res["tangled_element"],
        )
        compare_plot.savefig(
            os.path.join(self.plot_path, f"plot_{self.idx}_{cur_step}.png")
        )
        # Release the helper's figure; otherwise one figure leaks per step.
        if isinstance(compare_plot, plt.Figure):
            plt.close(compare_plot)
        # put coords back to original position (for u sampling)
        self.mesh.coordinates.dat.data[:] = self.init_coord
        self._plot_solutions(uh_new_0, model_coord, ma_coord, num_tangle, cur_step)

    def _infer_mesh(self, sample):
        """
        Run the model on ``sample``, which must already be on ``self.device``.

        Returns:
        - (out, dur_ms): predicted node coordinates and inference time in ms.

        Raises:
        - Exception: if ``model_used`` is not a supported model name.
        """
        self.model.eval()
        self.model = self.model.to(self.device)
        bs = 1
        with torch.no_grad():
            self._synchronize_device()
            start = time.perf_counter()
            if self.model_used == "MRTransformer" or self.model_used == "MRT":
                # Mesh query for the deformer, kept separate from the original
                # mesh that is fed to the encoder as features.
                mesh_query_x = sample.mesh_feat[:, 0].view(-1, 1).detach().clone()
                mesh_query_y = sample.mesh_feat[:, 1].view(-1, 1).detach().clone()
                mesh_query_x.requires_grad = True
                mesh_query_y.requires_grad = True
                mesh_query = torch.cat([mesh_query_x, mesh_query_y], dim=-1)
                num_nodes = mesh_query.shape[-2] // bs
                input_q = sample.mesh_feat[:, :4]
                # The model is called with sampled_queries=None, so no random
                # queries are generated: doing so only inflated the timing.
                (output_coord_all, _output, _out_monitor), (_phix, _phiy) = (
                    self.model(
                        sample,
                        input_q,
                        input_q,
                        mesh_query,
                        sampled_queries=None,
                        sampled_queries_edge_index=None,
                    )
                )
                out = output_coord_all[: num_nodes * bs]
            elif self.model_used == "M2N" or self.model_used == "MRN":
                out = self.model(sample)
            else:
                raise Exception(f"model {self.model_used} not implemented.")
            self._synchronize_device()
            end = time.perf_counter()
        return out, (end - start) * 1000

    def _synchronize_device(self):
        """Wait for queued CUDA work so wall-clock timings include it."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(torch.device(self.device))

    def _plot_solutions(self, uh_new_0, model_coord, ma_coord, num_tangle, cur_step):
        """
        Save a 2x2 figure of the MA-mesh solution (3D and 2D) and, if the model
        mesh is not tangled, of the model-mesh solution, to ``plot_more_path``.
        """
        fig = plt.figure(figsize=(8, 8))
        # 3D plot of MA solution
        ax1 = fig.add_subplot(2, 2, 1, projection="3d")
        ax1.set_title("MA Solution (3D)")
        fd.trisurf(uh_new_0, axes=ax1)
        if num_tangle == 0:
            # one coarse step on the mesh moved to the model coordinates
            self.mesh.coordinates.dat.data[:] = model_coord
            function_space = fd.FunctionSpace(self.mesh, "CG", 1)
            self.project_u_()
            fd.solve(self.F == 0, self.u)
            u_adapt_coarse_0 = fd.Function(function_space)
            u_adapt_coarse_0.project(self.u[0])
            ax2 = fig.add_subplot(2, 2, 2, projection="3d")
            ax2.set_title("Model Solution (3D)")
            fd.trisurf(u_adapt_coarse_0, axes=ax2)
            # 2d plot and mesh for Model
            ax4 = fig.add_subplot(2, 2, 4)
            ax4.set_title("Solution on Model mesh")
            fd.tripcolor(u_adapt_coarse_0, cmap="coolwarm", axes=ax4)
            self.mesh_new.coordinates.dat.data[:] = model_coord
            fd.triplot(self.mesh_new, axes=ax4)
            self.mesh.coordinates.dat.data[:] = self.init_coord
        # 2d plot and mesh for MA. uh_new_0 is defined on mesh_new, which the
        # branch above may have moved to the model coordinates, so restore
        # the MA coordinates before drawing it.
        self.mesh_new.coordinates.dat.data[:] = ma_coord
        ax3 = fig.add_subplot(2, 2, 3)
        ax3.set_title("Solution on MA mesh")
        fd.tripcolor(uh_new_0, cmap="coolwarm", axes=ax3)
        fd.triplot(self.mesh_new, axes=ax3)
        fig.savefig(
            os.path.join(self.plot_more_path, f"plot_{self.idx}_{cur_step}.png")
        )
        plt.close(fig)

    def get_error(self):
        """
        Return L2 errors of one coarse step on the uniform and adapted meshes.

        The fine step is (re-)solved from ``u_fine_`` as the reference. One
        coarse step is then solved from ``u_fine_buffer``, first on the
        uniform mesh (``init_coord``) and then on ``adapt_coord``. Each coarse
        x-velocity is projected to P1 on the fine mesh and compared with the
        fine x-velocity.

        Returns:
        - (error_og, error_adapt): L2 error on the uniform and on the adapted
          mesh. The coarse mesh is left at ``init_coord``.
        """
        function_space_fine = fd.FunctionSpace(self.mesh_fine, "CG", 1)
        # reference: one step on the fine mesh
        fd.solve(self.F_fine == 0, self.u_fine)
        u_fine_0 = fd.Function(function_space_fine)
        u_f = u_fine_0.project(self.u_fine[0])
        # one coarse step on the uniform mesh
        self.mesh.coordinates.dat.data[:] = self.init_coord
        function_space = fd.FunctionSpace(self.mesh, "CG", 1)
        self.project_u_()
        fd.solve(self.F == 0, self.u)
        u_0_fine = fd.Function(function_space_fine)
        u_0_coarse = fd.Function(function_space)
        u_0_coarse.project(self.u[0])
        u_0_fine.project(u_0_coarse)
        error_og = fd.errornorm(u_0_fine, u_f, norm_type="L2")
        # one coarse step on the adapted mesh
        self.mesh.coordinates.dat.data[:] = self.adapt_coord
        function_space = fd.FunctionSpace(self.mesh, "CG", 1)
        self.project_u_()
        fd.solve(self.F == 0, self.u)
        u_adapt_fine_0 = fd.Function(function_space_fine)
        u_adapt_coarse_0 = fd.Function(function_space)
        u_adapt_coarse_0.project(self.u[0])
        u_adapt_fine_0.project(u_adapt_coarse_0)
        error_adapt = fd.errornorm(u_adapt_fine_0, u_f, norm_type="L2")
        self.mesh.coordinates.dat.data[:] = self.init_coord
        return error_og, error_adapt

    def make_log_dir(self):
        """Create the CSV log directory if it does not exist."""
        UM2N.mkdir_if_not_exist(self.log_path)

    def make_plot_dir(self):
        """Create the mesh-comparison plot directory if it does not exist."""
        UM2N.mkdir_if_not_exist(self.plot_path)

    def make_plot_more_dir(self):
        """Create the detailed solution-plot directory if it does not exist."""
        UM2N.mkdir_if_not_exist(self.plot_more_path)

    def make_plot_data_dir(self):
        """Create the plot-data directory if it does not exist."""
        UM2N.mkdir_if_not_exist(self.plot_data_path)