# Author: Chunyang Wang
# GitHub Username: acse-cw1722
import firedrake as fd
import matplotlib.pyplot as plt # noqa
import UM2N
# Public API of this module. (A stray Sphinx "[docs]" link artifact that used
# to follow this line was removed: it raised NameError at import time.)
__all__ = ["compare_error"]
def _solve_on_mesh(eq, mesh):
    """Discretise ``eq`` on ``mesh`` and solve it with ``UM2N.EquationSolver``.

    Args:
        eq: An equation generator exposing ``discretise(mesh)``, which returns a
            dict with "function_space", "LHS", "RHS" and "bc" entries.
        mesh: The Firedrake mesh to discretise on.

    Returns:
        Whatever ``EquationSolver.solve_eq()`` returns for this discretisation.
    """
    res = eq.discretise(mesh)
    solver = UM2N.EquationSolver(
        params={
            "function_space": res["function_space"],
            "LHS": res["LHS"],
            "RHS": res["RHS"],
            "bc": res["bc"],
        }
    )
    return solver.solve_eq()


def compare_error(
    data_in,
    mesh,
    mesh_fine,
    mesh_model,
    mesh_MA,
    num_tangle,
    model_name,
    problem_type="helmholtz",
):
    """Solve a PDE on several meshes and compare the resulting errors.

    An analytical solution is rebuilt from ``data_in.dist_params``. The chosen
    PDE is then solved on the original mesh, on the MA-adapted mesh, and (when
    it has no tangled elements) on the model-adapted mesh. The results, plus the
    exact solution interpolated on ``mesh_fine``, are passed to
    ``UM2N.plot_compare``.

    Args:
        data_in: Sample object providing ``dist_params`` (a dict with keys
            "σ_x", "σ_y", "μ_x", "μ_y", "z", "w", "simple_u", "n_dist"),
            ``y`` (target MA coordinates) and ``monitor_val``. The tensor types
            are presumably torch; ``.cpu()``/``.detach()`` are called on them.
        mesh: Original (uniform) mesh.
        mesh_fine: High-resolution mesh used for the reference solution.
        mesh_model: Mesh adapted by the learned model.
        mesh_MA: Mesh whose coordinates are overwritten IN PLACE with
            ``data_in.y`` before solving.
        num_tangle: Number of tangled elements in ``mesh_model``. The model
            mesh is only solved on when this is 0.
        model_name: Model label used in plots ("MRTransformer" is shown as
            "M2T").
        problem_type: "helmholtz" (default) or "poisson".

    Returns:
        dict with keys "error_model_mesh" (-1 when unavailable),
        "error_og_mesh", "error_ma_mesh", "u_exact", "plot_more" (the figure)
        and "plot_data_dict".

    Raises:
        ValueError: If ``problem_type`` is not "helmholtz" or "poisson".
    """
    # Read in the parameters used to construct the analytical solution.
    params = data_in.dist_params
    σ_x = params["σ_x"][0]
    σ_y = params["σ_y"][0]
    μ_x = params["μ_x"][0]
    μ_y = params["μ_y"][0]
    z = params["z"][0]
    w = params["w"][0]
    simple_u = params["simple_u"].cpu().numpy()[0]
    n_dist = int(params["n_dist"].cpu().numpy()[0])

    if model_name == "MRTransformer":
        model_name = "M2T"

    # Construct u_exact as a sum of Gaussian bumps.
    if simple_u:
        # Isotropic bumps whose width is controlled by ``w``.
        def u_exact(x, y):
            temp = 0
            for i in range(n_dist):
                temp += fd.exp(
                    -1 * ((((x - μ_x[i]) ** 2) + ((y - μ_y[i]) ** 2)) / w[i])
                )
            return temp

    else:
        # Anisotropic bumps scaled by ``z``, with widths σ_x / σ_y.
        def u_exact(x, y):
            temp = 0
            for i in range(n_dist):
                temp += z[i] * fd.exp(
                    -1
                    * (
                        (((x - μ_x[i]) ** 2) / (σ_x[i] ** 2))
                        + (((y - μ_y[i]) ** 2) / (σ_y[i] ** 2))
                    )
                )
            return temp

    # Construct the requested equation (Helmholtz or Poisson).
    if problem_type == "helmholtz":
        eq = UM2N.HelmholtzEqGenerator(params={"u_exact_func": u_exact})
    elif problem_type == "poisson":
        eq = UM2N.PoissonEqGenerator(params={"u_exact_func": u_exact})
    else:
        raise ValueError(
            f"Unsupported problem_type {problem_type!r}; "
            "expected 'helmholtz' or 'poisson'."
        )

    # Solution on the original mesh.
    uh_og = _solve_on_mesh(eq, mesh)

    # Solution on the MA mesh: move its vertices to the stored target
    # coordinates first (this mutates the caller's mesh_MA).
    mesh_MA.coordinates.dat.data[:] = data_in.y.detach().cpu().numpy()
    uh_ma = _solve_on_mesh(eq, mesh_MA)

    # Solution on the model mesh, skipped when it contains tangled elements.
    uh_model = _solve_on_mesh(eq, mesh_model) if num_tangle == 0 else None

    # Exact solution interpolated onto the high-resolution mesh. Use
    # Function.interpolate: on recent Firedrake, fd.interpolate(expr, V)
    # returns an unassembled symbolic object rather than a Function.
    high_res_function_space = fd.FunctionSpace(mesh_fine, "CG", 1)
    res_high_res = eq.discretise(mesh_fine)
    uh_exact = fd.Function(high_res_function_space).interpolate(
        res_high_res["u_exact"]
    )

    fig, plot_data_dict = UM2N.plot_compare(
        mesh_fine,
        mesh,
        mesh_MA,
        mesh_model,
        uh_exact,
        uh_og,
        uh_ma,
        uh_model,
        data_in.monitor_val[:, 0].detach().cpu().numpy(),
        # NOTE(review): the same monitor values are passed twice; confirm
        # whether plot_compare expects a different array for this argument.
        data_in.monitor_val[:, 0].detach().cpu().numpy(),
        num_tangle,
        model_name,
    )

    error_og_mesh = plot_data_dict["error_norm_original"]
    error_ma_mesh = plot_data_dict["error_norm_ma"]
    # Presumably plot_compare only records the model error when the model
    # mesh was evaluated (num_tangle == 0); fall back to the -1 sentinel the
    # earlier inline implementation used instead of raising KeyError.
    error_model_mesh = plot_data_dict.get("error_norm_model", -1)

    return {
        "error_model_mesh": error_model_mesh,
        "error_og_mesh": error_og_mesh,
        "error_ma_mesh": error_ma_mesh,
        # NOTE(review): this is the Python callable u_exact(x, y), not the
        # interpolated field (uh_exact); confirm this is what callers want.
        "u_exact": u_exact,
        "plot_more": fig,
        "plot_data_dict": plot_data_dict,
    }