"""
Sequences of meshes corresponding to a :class:`~.TimePartition`.
"""
from collections.abc import Iterable
import firedrake
import numpy as np
from animate.interpolation import transfer
from animate.quality import QualityMeasure
from animate.utility import Mesh
from firedrake.adjoint import pyadjoint
from firedrake.petsc import PETSc
from firedrake.pyplot import triplot
from .function_data import ForwardSolutionData
from .log import DEBUG, debug, info, logger, pyrint, warning
from .options import AdaptParameters
from .utility import AttrDict
__all__ = ["MeshSeq"]
[docs]
class MeshSeq:
    """
    A sequence of meshes for solving a PDE associated with a particular
    :class:`~.TimePartition` of the temporal domain.

    The sequence can be indexed by subinterval to get or set the corresponding
    mesh, and its length is the number of meshes that it holds.
    """
@PETSc.Log.EventDecorator()
def __init__(self, time_partition, initial_meshes, **kwargs):
    r"""
    Set up a mesh sequence for a given partition of the temporal domain.

    :arg time_partition: a partition of the temporal domain
    :type time_partition: :class:`~.TimePartition`
    :arg initial_meshes: a list of meshes corresponding to the subinterval of the
        time partition, or a single mesh to use for all subintervals
    :type initial_meshes: :class:`list` or :class:`~.MeshGeometry`
    :kwarg get_function_spaces: a function as described in
        :meth:`~.MeshSeq.get_function_spaces`
    :kwarg get_initial_condition: a function as described in
        :meth:`~.MeshSeq.get_initial_condition`
    :kwarg get_solver: a function as described in :meth:`~.MeshSeq.get_solver`
    :kwarg transfer_method: the method to use for transferring fields between
        meshes. Options are "project" (default) and "interpolate". See
        :func:`animate.interpolation.transfer` for details
    :type transfer_method: :class:`str`
    :kwarg transfer_kwargs: kwargs to pass to the chosen transfer method
    :type transfer_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    """
    # Quantities inherited from the time partition
    self.time_partition = time_partition
    self.fields = dict.fromkeys(time_partition.field_names)
    self.field_types = dict(zip(self.fields, time_partition.field_types))
    self.subintervals = time_partition.subintervals
    self.num_subintervals = time_partition.num_subintervals
    self.steady = time_partition.steady

    # The meshes must be in place before anything relying on len(self)
    self.set_meshes(initial_meshes)
    self._fs = None

    # User-provided functions and transfer settings
    lookup = kwargs.get
    self._get_function_spaces = lookup("get_function_spaces")
    self._get_initial_condition = lookup("get_initial_condition")
    self._get_solver = lookup("get_solver")
    self._transfer_method = lookup("transfer_method", "project")
    self._transfer_kwargs = lookup("transfer_kwargs", {})

    # Book-keeping for the fixed point iteration
    num_meshes = len(self)
    self.check_convergence = np.ones(num_meshes, dtype=bool)
    self.converged = np.zeros(num_meshes, dtype=bool)
    self.fp_iteration = 0
    self.params = None
    self.sections = [{} for _ in range(num_meshes)]

    # Validate the outputs of any user-provided functions
    self._outputs_consistent()
def __str__(self):
    """
    :returns: the string form of the list of stringified meshes
    :rtype: :class:`str`
    """
    mesh_strings = [str(mesh) for mesh in self.meshes]
    return str(mesh_strings)
def __repr__(self):
    """
    :returns: an abbreviated representation showing the first and last meshes, with
        an ellipsis in between if there are more than two
    :rtype: :class:`str`
    """
    cls_name = type(self).__name__
    count = len(self)
    meshes = self.meshes
    if count == 1:
        body = repr(meshes[0])
    elif count == 2:
        body = ", ".join([repr(meshes[0]), repr(meshes[1])])
    else:
        body = ", ..., ".join([repr(meshes[0]), repr(meshes[-1])])
    return f"{cls_name}([{body}])"
[docs]
def debug(self, msg):
    """
    Log a ``debug`` level message, prefixed by the class name.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    label = type(self).__name__
    debug(f"{label}: {msg}")
[docs]
def warning(self, msg):
    """
    Log a ``warning`` level message, prefixed by the class name.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    label = type(self).__name__
    warning(f"{label}: {msg}")
[docs]
def info(self, msg):
    """
    Log an ``info`` level message, prefixed by the class name.

    :arg msg: the message to print
    :type msg: :class:`str`
    """
    label = type(self).__name__
    info(f"{label}: {msg}")
def __len__(self):
    """
    :returns: the number of meshes held by the sequence
    :rtype: :class:`int`
    """
    meshes = self.meshes
    return len(meshes)
def __getitem__(self, subinterval):
    """
    Look up the mesh associated with a subinterval.

    :arg subinterval: a subinterval index
    :type subinterval: :class:`int`
    :returns: the corresponding mesh
    :rtype: :class:`firedrake.MeshGeometry`
    """
    meshes = self.meshes
    return meshes[subinterval]
def __setitem__(self, subinterval, mesh):
    """
    Replace the mesh associated with a subinterval.

    :arg subinterval: a subinterval index
    :type subinterval: :class:`int`
    :arg mesh: the mesh to use for that subinterval
    :type mesh: :class:`firedrake.MeshGeometry`
    """
    meshes = self.meshes
    meshes[subinterval] = mesh
[docs]
def count_elements(self):
    r"""
    Count the number of elements in each mesh in the sequence.

    For each mesh, the size of the coordinate field's cell set is reduced over
    ``firedrake.COMM_WORLD``.

    :returns: list of element counts
    :rtype: :class:`list` of :class:`int`\s
    """
    world = firedrake.COMM_WORLD
    counts = []
    for mesh in self:
        counts.append(world.allreduce(mesh.coordinates.cell_set.size))
    return counts
[docs]
def count_vertices(self):
    r"""
    Count the number of vertices in each mesh in the sequence.

    For each mesh, the size of the coordinate field's node set is reduced over
    ``firedrake.COMM_WORLD``.

    :returns: list of vertex counts
    :rtype: :class:`list` of :class:`int`\s
    """
    world = firedrake.COMM_WORLD
    counts = []
    for mesh in self:
        counts.append(world.allreduce(mesh.coordinates.node_set.size))
    return counts
def _reset_counts(self):
    """
    Restart the element and vertex count histories so that each holds a single
    entry: the counts for the current meshes.
    """
    current_elements = self.count_elements()
    self.element_counts = [current_elements]
    current_vertices = self.count_vertices()
    self.vertex_counts = [current_vertices]
[docs]
def set_meshes(self, meshes):
    r"""
    Set all meshes in the sequence and deduce various properties.

    The topological dimension is stored as :attr:`dim` and the element and vertex
    count histories are reset. When the logger is at debug level, a summary line is
    logged for each mesh.

    :arg meshes: list of meshes to use in the sequence, or a single mesh to use for
        all subintervals
    :type meshes: :class:`list` of :class:`firedrake.MeshGeometry`\s or
        :class:`firedrake.MeshGeometry`
    :raises ValueError: if the meshes do not share a topological dimension
    """
    # TODO #122: Refactor to use the set method
    if not isinstance(meshes, Iterable):
        # A single mesh was given: make one Mesh from it per subinterval
        base_mesh = meshes
        meshes = [Mesh(base_mesh) for _ in self.subintervals]
    self.meshes = meshes

    dims = np.array([mesh.topological_dimension() for mesh in meshes])
    if dims.min() != dims.max():
        raise ValueError("Meshes must all have the same topological dimension.")
    self.dim = dims.min()
    self._reset_counts()

    # The remaining work is purely diagnostic
    if logger.level != DEBUG:
        return
    initial_elements = self.element_counts[0]
    initial_vertices = self.vertex_counts[0]
    for i, mesh in enumerate(meshes):
        nc = initial_elements[i]
        nv = initial_vertices[i]
        aspect_ratio = QualityMeasure(mesh)("aspect_ratio")
        mar = aspect_ratio.vector().gather().max()
        self.debug(
            f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}"
        )
    debug(100 * "-")
[docs]
def plot(self, fig=None, axes=None, **kwargs):
    """
    Plot the meshes comprising a 2D :class:`~.MeshSeq`.

    If either ``fig`` or ``axes`` is not provided then a new figure is created with
    one column per mesh. Mesh edges are drawn in black unless overridden through the
    ``interior_kw`` or ``boundary_kw`` keyword arguments.

    :kwarg fig: matplotlib figure to use
    :type fig: :class:`matplotlib.figure.Figure`
    :kwarg axes: matplotlib axes to use
    :type axes: :class:`matplotlib.axes._axes.Axes`
    :returns: matplotlib figure and axes for the plots
    :rtype: :class:`tuple` of :class:`matplotlib.figure.Figure` and
        :class:`matplotlib.axes._axes.Axes`
    :raises ValueError: if the meshes are not two-dimensional

    All keyword arguments are passed to :func:`firedrake.pyplot.triplot`.
    """
    from matplotlib.pyplot import subplots

    if self.dim != 2:
        raise ValueError("MeshSeq plotting only supported in 2D.")

    # Apply default edge colours, letting user-provided styles take precedence
    for key in ("interior_kw", "boundary_kw"):
        style = {"edgecolor": "k"}
        style.update(kwargs.pop(key, {}))
        kwargs[key] = style

    if fig is None or axes is None:
        num_meshes = len(self)
        fig, axes = subplots(ncols=num_meshes, nrows=1, figsize=(5 * num_meshes, 5))

    # Normalise to an iterable of iterables so that a single loop handles all layouts
    if not isinstance(axes, Iterable):
        axes = [axes]
    mesh_index = 0
    for row, group in enumerate(axes):
        if not isinstance(group, Iterable):
            group = [group]
        for ax in group:
            ax.set_title(f"MeshSeq[{mesh_index}]")
            triplot(self.meshes[mesh_index], axes=ax, **kwargs)
            ax.axis(False)
            mesh_index += 1
        # Unwrap singleton groups again
        if len(group) == 1:
            axes[row] = group[0]
    if len(axes) == 1:
        axes = axes[0]
    return fig, axes
[docs]
def get_function_spaces(self, mesh):
    """
    Construct the function spaces corresponding to each field, for a given mesh.

    This delegates to the ``get_function_spaces`` function provided at construction.

    :arg mesh: the mesh to base the function spaces on
    :type mesh: :class:`firedrake.mesh.MeshGeometry`
    :returns: a dictionary whose keys are field names and whose values are the
        corresponding function spaces
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.functionspaceimpl.FunctionSpace` values
    :raises NotImplementedError: if no ``get_function_spaces`` function was provided
    """
    builder = self._get_function_spaces
    if builder is not None:
        return builder(mesh)
    raise NotImplementedError("'get_function_spaces' needs implementing.")
[docs]
def get_initial_condition(self):
    r"""
    Get the initial conditions applied on the first mesh in the sequence.

    If a ``get_initial_condition`` function was provided at construction then it is
    called with the mesh sequence. Otherwise, a newly constructed
    :class:`firedrake.function.Function` is created for each field in its function
    space on the first mesh.

    :returns: the dictionary, whose keys are field names and whose values are the
        corresponding initial conditions applied
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    user_function = self._get_initial_condition
    if user_function is None:
        return {
            field: firedrake.Function(spaces[0])
            for field, spaces in self.function_spaces.items()
        }
    return user_function(self)
[docs]
def get_solver(self):
    """
    Get the function mapping a subinterval index and an initial condition dictionary
    to a dictionary of solutions for the corresponding solver step.

    This delegates to the ``get_solver`` function provided at construction, which is
    called with the mesh sequence.

    Signature for the function to be returned:
    ```
    :arg index: the subinterval index
    :type index: :class:`int`
    :arg ic: map from fields to the corresponding initial condition components
    :type ic: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    :return: map from fields to the corresponding solutions
    :rtype: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    ```

    :returns: the function for obtaining the solver
    :rtype: see docstring above
    :raises NotImplementedError: if no ``get_solver`` function was provided
    """
    builder = self._get_solver
    if builder is not None:
        return builder(self)
    raise NotImplementedError("'get_solver' needs implementing.")
def _transfer(self, source, target_space, **kwargs):
    """
    Transfer a field between meshes using the specified transfer method.

    :arg source: the function to be transferred
    :type source: :class:`firedrake.function.Function` or
        :class:`firedrake.cofunction.Cofunction`
    :arg target_space: the function space which we seek to transfer onto, or the
        function or cofunction to use as the target
    :type target_space: :class:`firedrake.functionspaceimpl.FunctionSpace`,
        :class:`firedrake.function.Function`
        or :class:`firedrake.cofunction.Cofunction`
    :returns: the transferred function
    :rtype: :class:`firedrake.function.Function` or
        :class:`firedrake.cofunction.Cofunction`

    Extra keyword arguments are passed to :func:`animate.interpolation.transfer`.
    Where a key also appears in the ``transfer_kwargs`` given at construction, the
    construction-time value is used.
    """
    # Start from the call-site kwargs, then let the user-specified ones override
    merged_kwargs = dict(kwargs)
    merged_kwargs.update(self._transfer_kwargs)
    return transfer(source, target_space, self._transfer_method, **merged_kwargs)
def _outputs_consistent(self):
    """
    Assert that function spaces and initial conditions are given in a
    dictionary format with :attr:`MeshSeq.fields` as keys.

    Functions which were not provided by the user are skipped. If a solver was
    provided then it is checked to be a generator and, at debug level, it is stepped
    once to see whether the current and lagged solutions of the first field differ.
    """
    for method in ("function_spaces", "initial_condition", "solver"):
        # Nothing to validate if the user did not provide this function
        if getattr(self, f"_get_{method}") is None:
            continue
        getter = getattr(self, f"get_{method}")

        if method == "solver":
            # The solver is validated by stepping it rather than inspecting a dict
            self._reinitialise_fields(self.get_initial_condition())
            solver_gen = getter()(0)
            assert hasattr(solver_gen, "__next__"), "solver should yield"
            if logger.level == DEBUG:
                next(solver_gen)
                first_field = next(iter(self.fields))
                current, lagged = self.fields[first_field]
                if np.array_equal(current.vector().array(), lagged.vector().array()):
                    self.debug(
                        "Current and lagged solutions are equal. Does the"
                        " solver yield before updating lagged solutions?"
                    )
            break

        if method == "function_spaces":
            output = getter(self.meshes[0])
        else:
            output = getter()
        assert isinstance(output, dict), f"get_{method} should return a dict"

        # The keys must match the fields exactly, in both directions
        expected = set(self.fields)
        provided = set(output.keys())
        diff = expected - provided
        assert len(diff) == 0, f"missing fields {diff} in get_{method}"
        diff = provided - expected
        assert len(diff) == 0, f"unexpected fields {diff} in get_{method}"
def _function_spaces_consistent(self):
    """
    Determine whether the mesh sequence's function spaces are consistent with its
    meshes.

    The checks are that there are as many meshes as subintervals, that each field
    has one function space per mesh, that each function space is defined on the
    corresponding mesh, and that all function spaces for a field share the same UFL
    element.

    :returns: ``True`` if the meshes and function spaces are consistent, otherwise
        ``False``
    :rtype: :class:`bool`
    """
    num_meshes = len(self)
    checks = [len(self.time_partition) == num_meshes]
    checks.append(all(len(self) == len(self._fs[field]) for field in self.fields))
    for field in self.fields:
        spaces = self._fs[field]
        checks.append(
            all(mesh == space.mesh() for mesh, space in zip(self.meshes, spaces))
        )
        checks.append(
            all(spaces[0].ufl_element() == space.ufl_element() for space in spaces)
        )
    return all(checks)
def _update_function_spaces(self):
    """
    Update the function space dictionary associated with the mesh sequence.

    The function spaces are rebuilt if they have not yet been constructed or if
    they are no longer consistent with the meshes. Either way, consistency is
    asserted before returning.

    :raises AssertionError: if the function spaces are inconsistent with the meshes
        after the update
    """
    if self._fs is None or not self._function_spaces_consistent():
        # Call get_function_spaces once per mesh and reuse the result for every
        # field, rather than rebuilding all spaces for each (field, mesh) pair
        spaces_per_mesh = [self.get_function_spaces(mesh) for mesh in self]
        self._fs = AttrDict(
            {
                field: [spaces[field] for spaces in spaces_per_mesh]
                for field in self.fields
            }
        )
    assert (
        self._function_spaces_consistent()
    ), "Meshes and function spaces are inconsistent"
@property
def function_spaces(self):
    """
    Get the function spaces associated with the mesh sequence, updating them first
    if required.

    :returns: a dictionary whose keys are field names and whose values are the
        corresponding function spaces
    :rtype: :class:`~.AttrDict` with :class:`str` keys and
        :class:`firedrake.functionspaceimpl.FunctionSpace` values
    """
    self._update_function_spaces()
    spaces = self._fs
    return spaces
@property
def initial_condition(self):
    """
    Get the initial conditions associated with the first subinterval, wrapped in an
    :class:`~.AttrDict`.

    :returns: a dictionary whose keys are field names and whose values are the
        corresponding initial conditions applied on the first subinterval
    :rtype: :class:`~.AttrDict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    raw_conditions = self.get_initial_condition()
    return AttrDict(raw_conditions)
@property
def solver(self):
    """
    Property form of :meth:`~.MeshSeq.get_solver`; see that method for details.
    """
    solver_function = self.get_solver()
    return solver_function
def _create_solutions(self):
    """
    Create the :class:`~.ForwardSolutionData` instance for holding solution data,
    built from the time partition and the current function spaces.
    """
    partition = self.time_partition
    spaces = self.function_spaces
    self._solutions = ForwardSolutionData(partition, spaces)
@property
def solutions(self):
    """
    Get the solution data object, creating it on first access.

    :returns: the solution data object
    :rtype: :class:`~.FunctionData`
    """
    try:
        return self._solutions
    except AttributeError:
        # Not created yet, so build it now
        self._create_solutions()
        return self._solutions
def _reinitialise_fields(self, initial_conditions):
    """
    Reinitialise fields and assign initial conditions.

    Steady fields are stored as a single function holding the initial condition.
    Other fields are stored as a ``(current, lagged)`` pair, with the initial
    condition assigned to the lagged function.

    :arg initial_conditions: the initial conditions to assign to lagged solutions
    :type initial_conditions: :class:`dict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    for field, ic in initial_conditions.items():
        space = ic.function_space()
        if self.field_types[field] != "steady":
            current = firedrake.Function(space, name=field)
            lagged = firedrake.Function(space, name=f"{field}_old").assign(ic)
            self.fields[field] = (current, lagged)
        else:
            self.fields[field] = firedrake.Function(space, name=f"{field}").assign(ic)
@PETSc.Log.EventDecorator()
def _solve_forward(self, update_solutions=True, solver_kwargs=None):
    r"""
    Solve a forward problem on a sequence of subintervals, as a generator which
    yields once per subinterval.

    After each subinterval other than the last, the final solution is transferred
    onto the function spaces of the next subinterval and yielded as a checkpoint.
    On the last subinterval no transfer takes place, so the checkpoint held at that
    point is yielded again.

    :kwarg update_solutions: if ``True``, updates the solution data
    :type update_solutions: :class:`bool`
    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
        values may take various types
    :yields: the checkpoint to use as initial condition on the next subinterval
    :ytype: :class:`~.AttrDict` with :class:`str` keys and
        :class:`firedrake.function.Function` values
    """
    solver_kwargs = solver_kwargs or {}
    final_index = len(self) - 1
    partition = self.time_partition

    if update_solutions:
        # Start from a fresh solution data object
        self._create_solutions()
        solutions = self.solutions.extract(layout="field")

    # Clear the working tape, if there is one, and switch annotation off
    if pyadjoint.annotate_tape():
        working_tape = pyadjoint.get_working_tape()
        if working_tape is not None:
            working_tape.clear_tape()
        pyadjoint.pause_annotation()

    checkpoint = self.initial_condition
    for i in range(final_index + 1):
        solver_gen = self.solver(i, **solver_kwargs)

        # Fields are rebuilt on this subinterval's spaces from the checkpoint
        self._reinitialise_fields(checkpoint)

        if not update_solutions:
            # March through the whole subinterval without storing anything
            for _ in range(partition.num_timesteps_per_subinterval[i]):
                next(solver_gen)
        else:
            # March from export to export, storing the solution at each one
            for j in range(partition.num_exports_per_subinterval[i] - 1):
                for _ in range(partition.num_timesteps_per_export[i]):
                    next(solver_gen)
                for field, sol in self.fields.items():
                    if self.field_types[field] == "steady":
                        assert isinstance(sol, firedrake.Function)
                        solutions[field].forward[i][j].assign(sol)
                    else:
                        assert isinstance(sol, tuple)
                        solutions[field].forward[i][j].assign(sol[0])
                        solutions[field].forward_old[i][j].assign(sol[1])

        # Hand the final solution over to the next subinterval's spaces
        if i < final_index:
            transferred = {}
            for field, spaces in self._fs.items():
                value = self.fields[field]
                source = value if self.field_types[field] == "steady" else value[0]
                transferred[field] = self._transfer(source, spaces[i + 1])
            checkpoint = AttrDict(transferred)

        yield checkpoint
[docs]
@PETSc.Log.EventDecorator()
def get_checkpoints(self, run_final_subinterval=False, solver_kwargs=None):
    r"""
    Get checkpoints corresponding to the starting fields on each subinterval.

    The first checkpoint is the initial condition. The remaining ones are obtained
    by stepping the forward solver without updating the solution data.

    :kwarg run_final_subinterval: if ``True``, the solver is run on the final
        subinterval
    :type run_final_subinterval: :class:`bool`
    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :returns: checkpoints for each subinterval
    :rtype: :class:`list` of :class:`firedrake.function.Function`\s
    """
    solver_kwargs = solver_kwargs or {}
    num_subintervals = len(self)
    checkpoints = [self.initial_condition]

    # With a single subinterval there is nothing further to solve for
    if num_subintervals == 1 and not run_final_subinterval:
        return checkpoints

    # The final subinterval only needs solving if explicitly requested
    num_solves = num_subintervals if run_final_subinterval else num_subintervals - 1
    solver_gen = self._solve_forward(
        update_solutions=False, solver_kwargs=solver_kwargs
    )
    for _ in range(num_solves):
        checkpoints.append(next(solver_gen))
    return checkpoints
[docs]
@PETSc.Log.EventDecorator()
def solve_forward(self, solver_kwargs=None):
    r"""
    Solve a forward problem on a sequence of subintervals.

    A dictionary of solution fields is computed - see :class:`~.ForwardSolutionData`
    for more details.

    :kwarg solver_kwargs: parameters for the forward solver
    :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
        values may take various types
    :returns: the solution data of the forward solves
    :rtype: :class:`~.ForwardSolutionData`
    """
    solver_kwargs = solver_kwargs or {}
    # Pass the solver parameters as a single dictionary: _solve_forward accepts them
    # through its `solver_kwargs` keyword, so unpacking them here would raise a
    # TypeError for any non-empty dictionary
    solver_gen = self._solve_forward(
        update_solutions=True, solver_kwargs=solver_kwargs
    )
    # Step through every subinterval so that the solution data is fully populated
    for _ in range(len(self)):
        next(solver_gen)
    return self.solutions
[docs]
def check_element_count_convergence(self):
    r"""
    Check for convergence of the fixed point iteration due to the relative
    difference in element count being smaller than the specified tolerance.

    The check compares the last two entries of :attr:`element_counts` and is only
    made once there are at least ``max(2, miniter + 1)`` entries. Subintervals whose
    :attr:`check_convergence` flag is ``False`` are skipped.

    If ``drop_out_converged`` is set then the existing :attr:`converged` array is
    updated in place and returned, and every subinterval from the first
    non-converged one onwards is marked as not converged. Otherwise, a new array is
    returned.

    :return: an array, whose entries are ``True`` if convergence is detected on the
        corresponding subinterval
    :rtype: :class:`numpy.ndarray` of :class:`bool`\s
    """
    if self.params.drop_out_converged:
        converged = self.converged
    else:
        converged = np.array([False] * len(self), dtype=bool)
    if len(self.element_counts) >= max(2, self.params.miniter + 1):
        for i, (ne_, ne) in enumerate(zip(*self.element_counts[-2:])):
            if not self.check_convergence[i]:
                self.info(
                    f"Skipping element count convergence check on subinterval {i}"
                    f" because check_convergence[{i}] == False."
                )
                continue
            if abs(ne - ne_) <= self.params.element_rtol * ne_:
                converged[i] = True
                if len(self) == 1:
                    pyrint(
                        f"Element count converged after {self.fp_iteration+1}"
                        " iterations under relative tolerance"
                        f" {self.params.element_rtol}."
                    )
                else:
                    pyrint(
                        f"Element count converged on subinterval {i} after"
                        f" {self.fp_iteration+1} iterations under relative"
                        f" tolerance {self.params.element_rtol}."
                    )

    # Check only early subintervals are marked as converged
    if self.params.drop_out_converged and not converged.all():
        # argmin returns the first occurrence of the minimum, i.e. the first False.
        # An argsort with the default (unstable) sort does not guarantee that.
        first_not_converged = converged.argmin()
        converged[first_not_converged:] = False
    return converged
[docs]
@PETSc.Log.EventDecorator()
def fixed_point_iteration(
    self,
    adaptor,
    parameters=None,
    update_params=None,
    solver_kwargs=None,
    adaptor_kwargs=None,
):
    r"""
    Apply mesh adaptation using a fixed point iteration loop approach.

    Each iteration solves the forward problem, calls the adaptor, records the new
    element and vertex counts and then checks for element count convergence. The
    loop ends once all subintervals have converged or after ``maxiter`` iterations,
    in which case a message is printed for each subinterval that did not converge.

    :arg adaptor: function for adapting the mesh sequence. Its arguments are the
        mesh sequence and the solution data object. It should return ``True`` if the
        convergence criteria checks are to be skipped for this iteration. Otherwise,
        it should return ``False``.
    :kwarg parameters: parameters to apply to the mesh adaptation process
    :type parameters: :class:`~.AdaptParameters`
    :kwarg update_params: function for updating :attr:`~.MeshSeq.params` at each
        iteration. Its arguments are the parameter class and the fixed point
        iteration
    :kwarg solver_kwargs: parameters to pass to the solver
    :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :kwarg adaptor_kwargs: parameters to pass to the adaptor
    :type adaptor_kwargs: :class:`dict` with :class:`str` keys and values which may
        take various types
    :returns: solution data object
    :rtype: :class:`~.ForwardSolutionData`
    """
    # TODO #124: adaptor no longer needs solution data to be passed explicitly
    self.params = parameters or AdaptParameters()
    solver_kwargs = solver_kwargs or {}
    adaptor_kwargs = adaptor_kwargs or {}

    # Start from a clean slate
    self._reset_counts()
    self.converged[:] = False
    self.check_convergence[:] = True

    for iteration in range(self.params.maxiter):
        self.fp_iteration = iteration
        if update_params is not None:
            update_params(self.params, self.fp_iteration)

        # Forward solve over all meshes, followed by adaptation
        self.solve_forward(solver_kwargs=solver_kwargs)
        continue_unconditionally = adaptor(self, self.solutions, **adaptor_kwargs)
        if self.params.drop_out_converged:
            self.check_convergence[:] = np.logical_not(
                np.logical_or(continue_unconditionally, self.converged)
            )

        # Log the counts for the adapted meshes
        self.element_counts.append(self.count_elements())
        self.vertex_counts.append(self.count_vertices())

        # Stop as soon as every subinterval has converged
        self.converged[:] = self.check_element_count_convergence()
        if self.converged.all():
            break
    else:
        # The iteration limit was reached: report what did not converge
        unconverged = [i for i, conv in enumerate(self.converged) if not conv]
        for i in unconverged:
            pyrint(
                f"Failed to converge on subinterval {i} in"
                f" {self.params.maxiter} iterations."
            )
    return self.solutions