"""
Sequences of meshes corresponding to a :class:`~.TimePartition`.
"""
from collections.abc import Iterable
import firedrake
import numpy as np
from animate.interpolation import transfer
from animate.quality import QualityMeasure
from animate.utility import Mesh, function_data_max
from firedrake.adjoint import pyadjoint
from firedrake.mesh import MeshSequenceGeometry
from firedrake.petsc import PETSc
from firedrake.pyplot import triplot
from .function_data import ForwardSolutionData
from .log import DEBUG, debug, info, logger, pyrint, warning
from .options import AdaptParameters
from .utility import AttrDict
__all__ = ["MeshSeq"]
[docs]
class MeshSeq:
    """
    A sequence of meshes for solving a PDE associated with a particular
    :class:`~.TimePartition` of the temporal domain.

    The sequence can be indexed by subinterval to get or set the corresponding mesh,
    and its length is the number of meshes it holds.
    """
@PETSc.Log.EventDecorator()
def __init__(self, time_partition, initial_meshes, **kwargs):
    r"""
    Set up a mesh sequence over the given partition of the temporal domain.

    :arg time_partition: a partition of the temporal domain
    :type time_partition: :class:`~.TimePartition`
    :arg initial_meshes: a list of meshes corresponding to the subinterval of the
        time partition, or a single mesh to use for all subintervals
    :type initial_meshes: :class:`list` or :class:`~.MeshGeometry`
    :kwarg get_initial_condition: a function as described in
        :meth:`~.MeshSeq.get_initial_condition`
    :kwarg get_solver: a function as described in :meth:`~.MeshSeq.get_solver`
    :kwarg transfer_method: the method to use for transferring fields between
        meshes. Options are "project" (default) and "interpolate". See
        :func:`animate.interpolation.transfer` for details
    :type transfer_method: :class:`str`
    :kwarg transfer_kwargs: kwargs to pass to the chosen transfer method
    :type transfer_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :raises KeyError: if the unsupported ``get_function_spaces`` keyword is passed
    """
    # Properties inherited from the time partition
    self.time_partition = time_partition
    self.subintervals = time_partition.subintervals
    self.num_subintervals = time_partition.num_subintervals
    self.field_names = time_partition.field_names
    self.field_metadata = time_partition.field_metadata

    # Names of the fields that are actually solved for
    solved_for = []
    for name in self.field_names:
        if self.field_metadata[name].solved_for:
            solved_for.append(name)
    self.solution_names = solved_for

    # Field Functions are keyed by field name and start out unset
    self.field_functions = {name: None for name in self.field_metadata}

    self.set_meshes(initial_meshes)
    self._fs = None

    if "get_function_spaces" in kwargs:
        raise KeyError(
            "get_function_spaces is no longer supported. Specify the finite_element"
            " argument for the Field class instead."
        )

    # User-provided callbacks and transfer options
    self._get_initial_condition = kwargs.get("get_initial_condition")
    self._get_solver = kwargs.get("get_solver")
    self._transfer_method = kwargs.get("transfer_method", "project")
    self._transfer_kwargs = kwargs.get("transfer_kwargs", {})

    # Fixed point iteration state, one flag per mesh
    self.steady = time_partition.steady
    num_meshes = len(self)
    self.check_convergence = np.ones(num_meshes, dtype=bool)
    self.converged = np.zeros(num_meshes, dtype=bool)
    self.fp_iteration = 0
    self.params = None
    self.sections = [{} for _ in self]

    self._outputs_consistent()
def __str__(self):
    """
    :returns: the string form of the list of ``str`` representations of the meshes
    :rtype: :class:`str`
    """
    mesh_strings = list(map(str, self.meshes))
    return f"{mesh_strings}"
def __repr__(self):
    """
    :returns: the class name wrapping the ``repr`` of the first and last meshes,
        with any intermediate meshes abbreviated to ``...``
    :rtype: :class:`str`
    """
    cls_name = type(self).__name__
    num_meshes = len(self)
    first = repr(self.meshes[0])
    if num_meshes == 1:
        inner = first
    elif num_meshes == 2:
        inner = f"{first}, {repr(self.meshes[1])}"
    else:
        inner = f"{first}, ..., {repr(self.meshes[-1])}"
    return f"{cls_name}([{inner}])"
[docs]
def debug(self, msg):
    """
    Emit a ``debug`` level message, prefixed with the name of this class.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    prefix = type(self).__name__
    debug(f"{prefix}: {msg}")
[docs]
def warning(self, msg):
    """
    Emit a ``warning`` level message, prefixed with the name of this class.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    prefix = type(self).__name__
    warning(f"{prefix}: {msg}")
[docs]
def info(self, msg):
    """
    Emit an ``info`` level message, prefixed with the name of this class.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    prefix = type(self).__name__
    info(f"{prefix}: {msg}")
def __len__(self):
    """
    :returns: the number of meshes in the sequence
    :rtype: :class:`int`
    """
    num_meshes = len(self.meshes)
    return num_meshes
def __getitem__(self, subinterval):
    """
    Look up the mesh used on a subinterval.

    :arg subinterval: a subinterval index
    :type subinterval: :class:`int`
    :returns: the corresponding mesh
    :rtype: :class:`firedrake.MeshGeometry`
    """
    mesh_list = self.meshes
    return mesh_list[subinterval]
def __setitem__(self, subinterval, mesh):
    """
    Replace the mesh used on a subinterval.

    :arg subinterval: a subinterval index
    :type subinterval: :class:`int`
    :arg mesh: the mesh to use for that subinterval
    :type mesh: :class:`firedrake.MeshGeometry`
    """
    mesh_list = self.meshes
    mesh_list[subinterval] = mesh
[docs]
def count_elements(self):
    r"""
    Count the number of elements in each mesh in the sequence.

    The size of each mesh's coordinate cell set is passed through an ``allreduce``
    on ``firedrake.COMM_WORLD``.

    :returns: list of element counts
    :rtype: :class:`list` of :class:`int`\s
    """
    world = firedrake.COMM_WORLD
    counts = []
    for mesh in self:
        counts.append(world.allreduce(mesh.coordinates.cell_set.size))
    return counts
[docs]
def count_vertices(self):
    r"""
    Count the number of vertices in each mesh in the sequence.

    The size of each mesh's coordinate node set is passed through an ``allreduce``
    on ``firedrake.COMM_WORLD``.

    :returns: list of vertex counts
    :rtype: :class:`list` of :class:`int`\s
    """
    world = firedrake.COMM_WORLD
    counts = []
    for mesh in self:
        counts.append(world.allreduce(mesh.coordinates.node_set.size))
    return counts
def _reset_counts(self):
    """
    Restart the element and vertex count histories so that each holds a single entry
    with the counts for the current meshes.
    """
    current_elements = self.count_elements()
    self.element_counts = [current_elements]
    current_vertices = self.count_vertices()
    self.vertex_counts = [current_vertices]
[docs]
def set_meshes(self, meshes):
    r"""
    Set all meshes in the sequence and deduce various properties.

    Anything that is not a :class:`list` is treated as a single mesh and is passed
    to ``Mesh`` once per subinterval. The topological dimension and the element and
    vertex counts are then refreshed.

    :arg meshes: list of meshes to use in the sequence, or a single mesh to use for
        all subintervals
    :type meshes: :class:`list` of :class:`firedrake.MeshGeometry`\s or
        :class:`firedrake.MeshGeometry`
    :raises ValueError: if the meshes do not share a topological dimension
    """
    # TODO #122: Refactor to use the set method
    if isinstance(meshes, list):
        mesh_list = meshes
    else:
        mesh_list = [Mesh(meshes) for _ in self.subintervals]
    self.meshes = mesh_list

    dims = np.array([m.topological_dimension for m in mesh_list])
    lowest = dims.min()
    if lowest != dims.max():
        raise ValueError("Meshes must all have the same topological dimension.")
    self.dim = lowest
    self._reset_counts()

    # Per-mesh statistics are only reported at the DEBUG logging level
    if logger.level != DEBUG:
        return
    for i, mesh in enumerate(mesh_list):
        nc = self.element_counts[0][i]
        nv = self.vertex_counts[0][i]
        mar = function_data_max(QualityMeasure(mesh)("aspect_ratio"))
        self.debug(
            f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}"
        )
    debug(100 * "-")
[docs]
def plot(self, fig=None, axes=None, **kwargs):
    """
    Plot the meshes comprising a 2D :class:`~.MeshSeq`.

    A new single-row figure with one column per mesh is created unless both ``fig``
    and ``axes`` are provided. Edges default to black unless overridden through the
    ``interior_kw`` and ``boundary_kw`` keyword arguments.

    :kwarg fig: matplotlib figure to use
    :type fig: :class:`matplotlib.figure.Figure`
    :kwarg axes: matplotlib axes to use
    :type axes: :class:`matplotlib.axes._axes.Axes`
    :returns: matplotlib figure and axes for the plots
    :rtype1: :class:`matplotlib.figure.Figure`
    :rtype2: :class:`matplotlib.axes._axes.Axes`
    :raises ValueError: if the meshes are not two-dimensional

    All keyword arguments are passed to :func:`firedrake.pyplot.triplot`.
    """
    from matplotlib.pyplot import subplots

    if self.dim != 2:
        raise ValueError("MeshSeq plotting only supported in 2D.")

    # Merge user-provided edge styles over the default black edges
    for key in ("interior_kw", "boundary_kw"):
        edge_style = {"edgecolor": "k"}
        edge_style.update(kwargs.pop(key, {}))
        kwargs[key] = edge_style

    if fig is None or axes is None:
        num_meshes = len(self)
        fig, axes = subplots(ncols=num_meshes, nrows=1, figsize=(5 * num_meshes, 5))
    if not isinstance(axes, Iterable):
        axes = [axes]

    # Walk the axes in order, drawing consecutive meshes of the sequence
    mesh_index = 0
    for row_index, row in enumerate(axes):
        row_axes = row if isinstance(row, Iterable) else [row]
        for ax in row_axes:
            ax.set_title(f"MeshSeq[{mesh_index}]")
            triplot(self.meshes[mesh_index], axes=ax, **kwargs)
            ax.axis(False)
            mesh_index += 1
        # Unwrap rows that hold a single axis
        if len(row_axes) == 1:
            axes[row_index] = row_axes[0]

    return fig, (axes[0] if len(axes) == 1 else axes)
def _get_field_metadata(self, fieldname):
    """
    Look up the metadata object for a named field.

    :arg fieldname: the name of the field
    :type fieldname: :class:`str`
    :returns: the entry of :attr:`field_metadata` for that field
    :raises ValueError: if the name is not in :attr:`field_names`
    """
    if fieldname in self.field_names:
        return self.field_metadata[fieldname]
    raise ValueError(f"Field '{fieldname}' is not associated with the MeshSeq.")
[docs]
def get_function_spaces(self, mesh):
    """
    Construct the function spaces corresponding to each field, for a given mesh.

    Each entry of :attr:`field_metadata` is asked for its function space on the
    mesh via its ``get_function_space`` method.

    :arg mesh: the mesh to base the function spaces on
    :type mesh: :class:`firedrake.mesh.MeshGeometry`
    :returns: a dictionary whose keys are field names and whose values are the
        corresponding function spaces
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.functionspaceimpl.FunctionSpace` values
    """
    return {
        name: metadata.get_function_space(mesh)
        for name, metadata in self.field_metadata.items()
    }
[docs]
def get_initial_condition(self):
    r"""
    Get the initial conditions applied on the first mesh in the sequence.

    If a ``get_initial_condition`` function was provided at construction then it is
    called with the mesh sequence. Otherwise, a default-constructed Function is
    created in the first function space of each field.

    :returns: the dictionary, whose keys are field names and whose values are the
        corresponding initial conditions applied
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    user_ic = self._get_initial_condition
    if user_ic is None:
        default_ics = {}
        for fieldname, fs in self.function_spaces.items():
            default_ics[fieldname] = firedrake.Function(fs[0])
        return default_ics
    return user_ic(self)
[docs]
def get_solver(self):
    """
    Get the function mapping a subinterval index and an initial condition dictionary
    to a dictionary of solutions for the corresponding solver step.

    The ``get_solver`` function provided at construction is called with the mesh
    sequence and its result is returned.

    Signature for the function to be returned:
    ```
    :arg index: the subinterval index
    :type index: :class:`int`
    :arg ic: map from fields to the corresponding initial condition components
    :type ic: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    :return: map from fields to the corresponding solutions
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    ```

    :returns: the function for obtaining the solver
    :rtype: see docstring above
    :raises NotImplementedError: if no ``get_solver`` function was provided
    """
    solver_factory = self._get_solver
    if solver_factory is not None:
        return solver_factory(self)
    raise NotImplementedError("'get_solver' needs implementing.")
def _transfer(self, source, target_space, **kwargs):
    """
    Transfer a field between meshes using the specified transfer method.

    :arg source: the function to be transferred
    :type source: :class:`firedrake.function.Function` or
        :class:`firedrake.cofunction.Cofunction`
    :arg target_space: the function space which we seek to transfer onto, or the
        function or cofunction to use as the target
    :type target_space: :class:`firedrake.functionspaceimpl.FunctionSpace`,
        :class:`firedrake.function.Function`
        or :class:`firedrake.cofunction.Cofunction`
    :returns: the transferred function
    :rtype: :class:`firedrake.function.Function` or
        :class:`firedrake.cofunction.Cofunction`

    Extra keyword arguments are passed to :func:`animate.interpolation.transfer`.
    Where a key is given both here and in the ``transfer_kwargs`` chosen at
    construction, the construction-time value is used.
    """
    merged_kwargs = dict(kwargs)
    merged_kwargs.update(self._transfer_kwargs)
    method = self._transfer_method
    return transfer(source, target_space, method, **merged_kwargs)
def _outputs_consistent(self):
    """
    Assert that the user-provided initial condition and solver functions are
    consistent with the fields of the mesh sequence.

    Checks are skipped for any function that was not provided. The initial
    condition must be a dictionary with exactly the same keys as
    :attr:`MeshSeq.field_functions`. The solver for the first subinterval must be a
    generator. At the DEBUG logging level, the solver is also stepped once and a
    message is logged if the current and lagged solutions of the first unsteady
    field are equal.

    :raises AssertionError: if any of the checks fail
    """
    for method in ["initial_condition", "solver"]:
        if getattr(self, f"_get_{method}") is None:
            continue
        method_map = getattr(self, f"get_{method}")
        if method == "initial_condition":
            method_map = method_map()
        elif method == "solver":
            self._reinitialise_fields(self.get_initial_condition())
            solver_gen = method_map()(0)
            assert hasattr(solver_gen, "__next__"), "solver should yield"
            if logger.level == DEBUG:
                next(solver_gen)
                # Steady fields are stored as a single Function rather than a
                # (current, lagged) pair, so only a pair can be compared
                lagged_pair = next(
                    (
                        sol
                        for sol in self.field_functions.values()
                        if isinstance(sol, tuple)
                    ),
                    None,
                )
                if lagged_pair is not None:
                    f, f_ = lagged_pair
                    if np.array_equal(f.dat.data_ro, f_.dat.data_ro):
                        self.debug(
                            "Current and lagged solutions are equal. Does the"
                            " solver yield before updating lagged solutions?"
                        )  # noqa
            # The solver does not return a dictionary, so the checks below do not
            # apply to it
            break
        assert isinstance(method_map, dict), f"get_{method} should return a dict"
        mesh_seq_fields = set(self.field_functions)
        method_fields = set(method_map.keys())
        diff = mesh_seq_fields.difference(method_fields)
        assert len(diff) == 0, f"missing fields {diff} in get_{method}"
        diff = method_fields.difference(mesh_seq_fields)
        assert len(diff) == 0, f"unexpected fields {diff} in get_{method}"
def _function_spaces_consistent(self):
    """
    Determine whether the mesh sequence's function spaces are consistent with its
    meshes.

    Consistency requires that the time partition and every field's list of function
    spaces have the same length as the mesh sequence, that each function space is
    defined on the corresponding mesh, and that all function spaces of a field share
    the same UFL element.

    :returns: ``True`` if the meshes and function spaces are consistent, otherwise
        ``False``
    :rtype: :class:`bool`
    """
    num_meshes = len(self)
    if len(self.time_partition) != num_meshes:
        return False
    # Mismatched lengths must be reported as an inconsistency here, since the
    # strict zips below would otherwise raise a ValueError
    if any(
        len(self._fs[fieldname]) != num_meshes for fieldname in self.field_functions
    ):
        return False

    for fieldname in self.field_functions:
        spaces = self._fs[fieldname]
        if isinstance(spaces[0].mesh(), MeshSequenceGeometry):
            # The space's mesh is itself a sequence, so compare every member
            meshes_match = all(
                mesh1 == mesh2
                for mesh1, fs in zip(self.meshes, spaces, strict=True)
                for mesh2 in fs.mesh()
            )
        else:
            meshes_match = all(
                mesh == fs.mesh()
                for mesh, fs in zip(self.meshes, spaces, strict=True)
            )
        if not meshes_match:
            return False
        reference_element = spaces[0].ufl_element()
        if not all(reference_element == fs.ufl_element() for fs in spaces):
            return False
    return True
def _update_function_spaces(self):
    """
    Update the function space dictionary associated with the mesh sequence.

    The dictionary is rebuilt if it has not been created yet or if it is no longer
    consistent with the meshes. It maps each field name to a list containing one
    function space per mesh.

    :raises AssertionError: if the function spaces are inconsistent with the meshes
        after the update
    """
    if self._fs is None or not self._function_spaces_consistent():
        # get_function_spaces builds the spaces for every field, so call it once
        # per mesh rather than once per (field, mesh) pair
        spaces_per_mesh = [self.get_function_spaces(mesh) for mesh in self]
        self._fs = AttrDict(
            {
                fieldname: [spaces[fieldname] for spaces in spaces_per_mesh]
                for fieldname in self.field_functions
            }
        )
    assert self._function_spaces_consistent(), (
        "Meshes and function spaces are inconsistent"
    )
@property
def function_spaces(self):
    """
    Get the function spaces associated with the mesh sequence, updating them first
    via :meth:`_update_function_spaces`.

    :returns: a dictionary whose keys are field names and whose values are the
        corresponding function spaces
    :rtype: :class:`~.AttrDict` with :class:`str` keys and
        :class:`firedrake.functionspaceimpl.FunctionSpace` values
    """
    self._update_function_spaces()
    spaces = self._fs
    return spaces
@property
def initial_condition(self):
    """
    Get the initial conditions associated with the first subinterval, as returned by
    :meth:`get_initial_condition`.

    :returns: a dictionary whose keys are field names and whose values are the
        corresponding initial conditions applied on the first subinterval
    :rtype: :class:`~.AttrDict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    ic_map = self.get_initial_condition()
    return AttrDict(ic_map)
@property
def solver(self):
    """
    The result of calling :meth:`~.MeshSeq.get_solver`; see that method for details.
    """
    solver_function = self.get_solver()
    return solver_function
def _create_solutions(self):
    """
    Create a fresh :class:`~.ForwardSolutionData` instance for holding solution
    data, built from the time partition and the current function spaces.
    """
    spaces = self.function_spaces
    self._solutions = ForwardSolutionData(self.time_partition, spaces)
@property
def solutions(self):
    """
    Get the solution data object, creating it on first access.

    :returns: the solution data object
    :rtype: :class:`~.FunctionData`
    """
    try:
        return self._solutions
    except AttributeError:
        self._create_solutions()
        return self._solutions
def _reinitialise_fields(self, initial_conditions):
    """
    Recreate the field Functions and assign the given initial conditions to them.

    Unsteady fields are stored as a ``(current, lagged)`` pair, with the initial
    condition assigned to the lagged Function. Other fields are stored as a single
    Function holding the initial condition.

    :arg initial_conditions: the initial conditions to assign to lagged solutions
    :type initial_conditions: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    for fieldname in self.field_names:
        ic = initial_conditions[fieldname]
        space = ic.function_space()
        metadata = self._get_field_metadata(fieldname)
        if not metadata.unsteady:
            steady_function = firedrake.Function(space, name=fieldname)
            self.field_functions[fieldname] = steady_function
            steady_function.assign(ic)
            continue
        current = firedrake.Function(space, name=fieldname)
        lagged = firedrake.Function(space, name=f"{fieldname}_old").assign(ic)
        self.field_functions[fieldname] = (current, lagged)
@PETSc.Log.EventDecorator()
def _solve_forward(self, update_solutions=True, solver_kwargs=None):
    r"""
    Solve a forward problem on a sequence of subintervals, as a generator.

    One value is yielded per subinterval, after the solver has been stepped over it.
    The yielded value is the checkpoint for the next subinterval, i.e., the current
    fields transferred onto the next subinterval's function spaces. On the final
    subinterval there is nothing to transfer onto, so the checkpoint from the
    previous iteration is yielded again unchanged.

    :kwarg update_solutions: if ``True``, updates the solution data
    :type update_solutions: :class:`bool`
    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
        values may take various types
    :yields: the checkpoint for the next subinterval
    :ytype: :class:`~.AttrDict` with :class:`str` keys
    """
    solver_kwargs = solver_kwargs or {}
    num_subintervals = len(self)
    tp = self.time_partition

    if update_solutions:
        # Reinitialise the solution data object
        self._create_solutions()
        solutions = self.solutions.extract(layout="field")

    # Stop annotating
    if pyadjoint.annotate_tape():
        tape = pyadjoint.get_working_tape()
        if tape is not None:
            tape.clear_tape()
        pyadjoint.pause_annotation()

    # Loop over the subintervals
    checkpoint = self.initial_condition
    for i in range(num_subintervals):
        solver_gen = self.solver(i, **solver_kwargs)

        # Reinitialise fields and assign initial conditions
        self._reinitialise_fields(checkpoint)

        if update_solutions:
            # Solve sequentially between each export time
            for j in range(tp.num_exports_per_subinterval[i] - 1):
                for _ in range(tp.num_timesteps_per_export[i]):
                    next(solver_gen)
                # Update the solution data
                for fieldname, sol in self.field_functions.items():
                    field = self._get_field_metadata(fieldname)
                    if field.unsteady:
                        # Unsteady fields are (current, lagged) pairs
                        assert isinstance(sol, tuple)
                        solutions[fieldname].forward[i][j].assign(sol[0])
                        solutions[fieldname].forward_old[i][j].assign(sol[1])
                    else:
                        assert isinstance(sol, firedrake.Function)
                        solutions[fieldname].forward[i][j].assign(sol)
        else:
            # Solve over the entire subinterval in one go
            for _ in range(tp.num_timesteps_per_subinterval[i]):
                next(solver_gen)

        # Transfer the checkpoint to the next subintervals
        if i < num_subintervals - 1:
            checkpoint = AttrDict(
                {
                    fieldname: self._transfer(
                        self.field_functions[fieldname][0]
                        if self._get_field_metadata(fieldname).unsteady
                        else self.field_functions[fieldname],
                        fs[i + 1],
                    )
                    for fieldname, fs in self._fs.items()
                }
            )

        yield checkpoint
[docs]
@PETSc.Log.EventDecorator()
def get_checkpoints(self, run_final_subinterval=False, solver_kwargs=None):
    r"""
    Get checkpoints corresponding to the starting fields on each subinterval.

    The first checkpoint is the initial condition. Each further checkpoint is the
    next value yielded by :meth:`_solve_forward`, run without updating the solution
    data.

    :kwarg run_final_subinterval: if ``True``, the solver is run on the final
        subinterval
    :type run_final_subinterval: :class:`bool`
    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :returns: checkpoints for each subinterval
    :rtype: :class:`list` of :class:`~.AttrDict`\s
    """
    solver_kwargs = solver_kwargs or {}
    num_meshes = len(self)

    # The initial condition is always the first checkpoint
    checkpoints = [self.initial_condition]

    # Work out how many subintervals need to be solved over
    if run_final_subinterval:
        num_solves = num_meshes
    elif num_meshes == 1:
        # A single subinterval that is not to be run: nothing more to do
        return checkpoints
    else:
        num_solves = num_meshes - 1

    forward_gen = self._solve_forward(
        update_solutions=False, solver_kwargs=solver_kwargs
    )
    for _ in range(num_solves):
        checkpoints.append(next(forward_gen))
    return checkpoints
[docs]
@PETSc.Log.EventDecorator()
def solve_forward(self, solver_kwargs=None):
    r"""
    Solve a forward problem on a sequence of subintervals.

    A dictionary of solution fields is computed - see :class:`~.ForwardSolutionData`
    for more details.

    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
        values may take various types
    :returns: the solution data of the forward solves
    :rtype: :class:`~.ForwardSolutionData`
    """
    solver_kwargs = solver_kwargs or {}
    # Pass the dictionary by keyword rather than unpacking it: _solve_forward takes
    # the solver parameters as a single ``solver_kwargs`` argument, so unpacked
    # entries would be rejected as unexpected keyword arguments
    solver_gen = self._solve_forward(
        update_solutions=True, solver_kwargs=solver_kwargs
    )
    # Step the generator once per subinterval so that all of them are solved
    for _ in range(len(self)):
        next(solver_gen)
    return self.solutions
[docs]
def check_element_count_convergence(self):
r"""
Check for convergence of the fixed point iteration due to the relative
difference in element count being smaller than the specified tolerance.
:return: an array, whose entries are ``True`` if convergence is detected on the
corresponding subinterval
:rtype: :class:`list` of :class:`bool`\s
"""
if self.params.drop_out_converged:
converged = self.converged
else:
converged = np.array([False] * len(self), dtype=bool)
if len(self.element_counts) >= max(2, self.params.miniter + 1):
for i, (ne_, ne) in enumerate(zip(*self.element_counts[-2:], strict=True)):
if not self.check_convergence[i]:
self.info(
f"Skipping element count convergence check on subinterval {i})"
f" because check_convergence[{i}] == False."
)
continue
if abs(ne - ne_) <= self.params.element_rtol * ne_:
converged[i] = True
if len(self) == 1:
pyrint(
f"Element count converged after {self.fp_iteration + 1}"
" iterations under relative tolerance"
f" {self.params.element_rtol}."
)
else:
pyrint(
f"Element count converged on subinterval {i} after"
f" {self.fp_iteration + 1} iterations under relative"
f" tolerance {self.params.element_rtol}."
)
# Check only early subintervals are marked as converged
if self.params.drop_out_converged and not converged.all():
first_not_converged = converged.argsort()[0]
converged[first_not_converged:] = False
return converged
[docs]
@PETSc.Log.EventDecorator()
def fixed_point_iteration(
    self,
    adaptor,
    parameters=None,
    update_params=None,
    solver_kwargs=None,
    adaptor_kwargs=None,
):
    r"""
    Apply mesh adaptation using a fixed point iteration loop approach.

    Each iteration solves the forward problem, calls the adaptor, records the new
    element and vertex counts, and checks for element count convergence. The loop
    stops early once every subinterval has converged. Otherwise, a message is
    printed for each unconverged subinterval after the final iteration.

    :arg adaptor: function for adapting the mesh sequence. Its arguments are the
        mesh sequence and the solution data object. It should return ``True`` if the
        convergence criteria checks are to be skipped for this iteration. Otherwise,
        it should return ``False``.
    :kwarg parameters: parameters to apply to the mesh adaptation process
    :type parameters: :class:`~.AdaptParameters`
    :kwarg update_params: function for updating :attr:`~.MeshSeq.params` at each
        iteration. Its arguments are the parameter class and the fixed point
        iteration
    :kwarg solver_kwargs: parameters to pass to the solver
    :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :kwarg adaptor_kwargs: parameters to pass to the adaptor
    :type adaptor_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :returns: solution data object
    :rtype: :class:`~.ForwardSolutionData`
    """
    # TODO #124: adaptor no longer needs solution data to be passed explicitly
    self.params = parameters or AdaptParameters()
    solver_kwargs = solver_kwargs or {}
    adaptor_kwargs = adaptor_kwargs or {}

    # Start from a clean slate
    self._reset_counts()
    self.converged[:] = False
    self.check_convergence[:] = True

    for iteration in range(self.params.maxiter):
        self.fp_iteration = iteration
        if update_params is not None:
            update_params(self.params, self.fp_iteration)

        # Solve the forward problem over all meshes
        self.solve_forward(solver_kwargs=solver_kwargs)

        # Adapt meshes, logging element and vertex counts
        skip_checks = adaptor(self, self.solutions, **adaptor_kwargs)
        if self.params.drop_out_converged:
            excluded = np.logical_or(skip_checks, self.converged)
            self.check_convergence[:] = np.logical_not(excluded)
        self.element_counts.append(self.count_elements())
        self.vertex_counts.append(self.count_vertices())

        # Finish as soon as the element counts have converged everywhere
        self.converged[:] = self.check_element_count_convergence()
        if self.converged.all():
            return self.solutions

    # The iteration limit was reached without full convergence
    for i, conv in enumerate(self.converged):
        if conv:
            continue
        pyrint(
            f"Failed to converge on subinterval {i} in"
            f" {self.params.maxiter} iterations."
        )
    return self.solutions