# Source code for goalie.mesh_seq

"""
Sequences of meshes corresponding to a :class:`~.TimePartition`.
"""

from collections.abc import Iterable

import firedrake
import numpy as np
from animate.interpolation import transfer
from animate.quality import QualityMeasure
from animate.utility import Mesh, function_data_max
from firedrake.adjoint import pyadjoint
from firedrake.mesh import MeshSequenceGeometry
from firedrake.petsc import PETSc
from firedrake.pyplot import triplot

from .function_data import ForwardSolutionData
from .log import DEBUG, debug, info, logger, pyrint, warning
from .options import AdaptParameters
from .utility import AttrDict

# Public API of this module: only the MeshSeq class is exported by star-imports.
__all__ = ["MeshSeq"]


class MeshSeq:
    """
    A sequence of meshes for solving a PDE associated with a particular
    :class:`~.TimePartition` of the temporal domain.
    """

    @PETSc.Log.EventDecorator()
    def __init__(self, time_partition, initial_meshes, **kwargs):
        r"""
        :arg time_partition: a partition of the temporal domain
        :type time_partition: :class:`~.TimePartition`
        :arg initial_meshes: a list of meshes corresponding to the subinterval of the
            time partition, or a single mesh to use for all subintervals
        :type initial_meshes: :class:`list` or :class:`~.MeshGeometry`
        :kwarg get_initial_condition: a function as described in
            :meth:`~.MeshSeq.get_initial_condition`
        :kwarg get_solver: a function as described in :meth:`~.MeshSeq.get_solver`
        :kwarg transfer_method: the method to use for transferring fields between
            meshes. Options are "project" (default) and "interpolate". See
            :func:`animate.interpolation.transfer` for details
        :type transfer_method: :class:`str`
        :kwarg transfer_kwargs: kwargs to pass to the chosen transfer method
        :type transfer_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :raises KeyError: if the no-longer-supported ``get_function_spaces`` kwarg
            is passed
        """
        self.time_partition = time_partition
        self.subintervals = time_partition.subintervals
        self.num_subintervals = time_partition.num_subintervals
        self.field_names = time_partition.field_names
        self.field_metadata = time_partition.field_metadata
        self.solution_names = [
            fieldname
            for fieldname in self.field_names
            if self.field_metadata[fieldname].solved_for
        ]

        # Create a dictionary to hold field Functions with field names as keys and
        # None as values
        self.field_functions = dict.fromkeys(self.field_metadata)
        self.set_meshes(initial_meshes)
        self._fs = None
        if "get_function_spaces" in kwargs:
            raise KeyError(
                "get_function_spaces is no longer supported. Specify the finite_element"
                " argument for the Field class instead."
            )
        self._get_initial_condition = kwargs.get("get_initial_condition")
        self._get_solver = kwargs.get("get_solver")
        self._transfer_method = kwargs.get("transfer_method", "project")
        self._transfer_kwargs = kwargs.get("transfer_kwargs", {})
        self.steady = time_partition.steady
        self.check_convergence = np.ones(len(self), dtype=bool)
        self.converged = np.zeros(len(self), dtype=bool)
        self.fp_iteration = 0
        self.params = None
        self.sections = [{} for _ in self]

        # Validate the user-provided callbacks early
        self._outputs_consistent()

    def __str__(self):
        return f"{[str(mesh) for mesh in self.meshes]}"

    def __repr__(self):
        name = type(self).__name__
        if len(self) == 1:
            return f"{name}([{repr(self.meshes[0])}])"
        elif len(self) == 2:
            return f"{name}([{repr(self.meshes[0])}, {repr(self.meshes[1])}])"
        else:
            return f"{name}([{repr(self.meshes[0])}, ..., {repr(self.meshes[-1])}])"

    def debug(self, msg):
        """
        Print a ``debug`` message, prefixed with the class name.

        :arg msg: the message to print
        :type msg: :class:`str`
        """
        debug(f"{type(self).__name__}: {msg}")

    def warning(self, msg):
        """
        Print a ``warning`` message, prefixed with the class name.

        :arg msg: the message to print
        :type msg: :class:`str`
        """
        warning(f"{type(self).__name__}: {msg}")

    def info(self, msg):
        """
        Print an ``info`` level message, prefixed with the class name.

        :arg msg: the message to print
        :type msg: :class:`str`
        """
        info(f"{type(self).__name__}: {msg}")

    def __len__(self):
        return len(self.meshes)

    def __getitem__(self, subinterval):
        """
        :arg subinterval: a subinterval index
        :type subinterval: :class:`int`
        :returns: the corresponding mesh
        :rtype: :class:`firedrake.MeshGeometry`
        """
        return self.meshes[subinterval]

    def __setitem__(self, subinterval, mesh):
        """
        :arg subinterval: a subinterval index
        :type subinterval: :class:`int`
        :arg mesh: the mesh to use for that subinterval
        :type mesh: :class:`firedrake.MeshGeometry`
        """
        self.meshes[subinterval] = mesh

    def count_elements(self):
        r"""
        Count the number of elements in each mesh in the sequence.

        The per-rank counts are summed over ``firedrake.COMM_WORLD``.

        :returns: list of element counts
        :rtype: :class:`list` of :class:`int`\s
        """
        comm = firedrake.COMM_WORLD
        return [comm.allreduce(mesh.coordinates.cell_set.size) for mesh in self]

    def count_vertices(self):
        r"""
        Count the number of vertices in each mesh in the sequence.

        The per-rank counts are summed over ``firedrake.COMM_WORLD``.

        :returns: list of vertex counts
        :rtype: :class:`list` of :class:`int`\s
        """
        comm = firedrake.COMM_WORLD
        return [comm.allreduce(mesh.coordinates.node_set.size) for mesh in self]

    def _reset_counts(self):
        """
        Reset the lists of element and vertex counts so that they only hold the
        counts for the current meshes.
        """
        self.element_counts = [self.count_elements()]
        self.vertex_counts = [self.count_vertices()]

    def set_meshes(self, meshes):
        r"""
        Set all meshes in the sequence and deduce various properties.

        Note that this resets the element and vertex count histories.

        :arg meshes: list (or tuple) of meshes to use in the sequence, or a single
            mesh to use for all subintervals
        :type meshes: :class:`list` of :class:`firedrake.MeshGeometry`\s or
            :class:`firedrake.MeshGeometry`
        :raises ValueError: if the meshes do not share a topological dimension
        """
        # TODO #122: Refactor to use the set method
        if isinstance(meshes, tuple):
            # Store a mutable list so that __setitem__ keeps working
            meshes = list(meshes)
        if not isinstance(meshes, list):
            meshes = [Mesh(meshes) for _ in self.subintervals]
        self.meshes = meshes
        dim = np.array([mesh.topological_dimension for mesh in meshes])
        if dim.min() != dim.max():
            raise ValueError("Meshes must all have the same topological dimension.")
        self.dim = dim.min()
        self._reset_counts()
        if logger.level == DEBUG:
            for i, mesh in enumerate(meshes):
                nc = self.element_counts[0][i]
                nv = self.vertex_counts[0][i]
                qm = QualityMeasure(mesh)
                ar = qm("aspect_ratio")
                mar = function_data_max(ar)
                self.debug(
                    f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}"
                )
            debug(100 * "-")

    def plot(self, fig=None, axes=None, **kwargs):
        """
        Plot the meshes comprising a 2D :class:`~.MeshSeq`.

        :kwarg fig: matplotlib figure to use
        :type fig: :class:`matplotlib.figure.Figure`
        :kwarg axes: matplotlib axes to use
        :type axes: :class:`matplotlib.axes._axes.Axes`
        :returns: matplotlib figure and axes for the plots
        :rtype1: :class:`matplotlib.figure.Figure`
        :rtype2: :class:`matplotlib.axes._axes.Axes`
        :raises ValueError: if the meshes are not 2D

        All keyword arguments are passed to :func:`firedrake.pyplot.triplot`. A new
        figure is created unless both ``fig`` and ``axes`` are given.
        """
        from matplotlib.pyplot import subplots

        if self.dim != 2:
            raise ValueError("MeshSeq plotting only supported in 2D.")

        # Process kwargs, defaulting to black edges unless overridden
        interior_kw = {"edgecolor": "k"}
        interior_kw.update(kwargs.pop("interior_kw", {}))
        boundary_kw = {"edgecolor": "k"}
        boundary_kw.update(kwargs.pop("boundary_kw", {}))
        kwargs["interior_kw"] = interior_kw
        kwargs["boundary_kw"] = boundary_kw
        if fig is None or axes is None:
            n = len(self)
            fig, axes = subplots(ncols=n, nrows=1, figsize=(5 * n, 5))

        # Loop over all axes and plot the meshes, handling both flat and nested
        # collections of axes
        k = 0
        if not isinstance(axes, Iterable):
            axes = [axes]
        for i, axis in enumerate(axes):
            if not isinstance(axis, Iterable):
                axis = [axis]
            for ax in axis:
                ax.set_title(f"MeshSeq[{k}]")
                triplot(self.meshes[k], axes=ax, **kwargs)
                ax.axis(False)
                k += 1
            if len(axis) == 1:
                axes[i] = axis[0]
        if len(axes) == 1:
            axes = axes[0]
        return fig, axes

    def _get_field_metadata(self, fieldname):
        """
        :arg fieldname: the name of the field
        :type fieldname: :class:`str`
        :returns: the metadata object for the field
        :raises ValueError: if the field is not associated with the mesh sequence
        """
        if fieldname not in self.field_names:
            raise ValueError(f"Field '{fieldname}' is not associated with the MeshSeq.")
        return self.field_metadata[fieldname]

    def get_function_spaces(self, mesh):
        """
        Construct the function spaces corresponding to each field, for a given mesh.

        :arg mesh: the mesh to base the function spaces on
        :type mesh: :class:`firedrake.mesh.MeshGeometry`
        :returns: a dictionary whose keys are field names and whose values are the
            corresponding function spaces
        :rtype: :class:`dict` with :class:`str` keys and
            :class:`firedrake.functionspaceimpl.FunctionSpace` values
        """
        return {
            fieldname: field.get_function_space(mesh)
            for fieldname, field in self.field_metadata.items()
        }

    def get_initial_condition(self):
        r"""
        Get the initial conditions applied on the first mesh in the sequence.

        If no ``get_initial_condition`` callback was provided, each field is
        initialised with a zero-valued :class:`firedrake.function.Function` on the
        first function space.

        :returns: the dictionary, whose keys are field names and whose values are the
            corresponding initial conditions applied
        :rtype: :class:`dict` with :class:`str` keys and
            :class:`firedrake.function.Function` values
        """
        if self._get_initial_condition is not None:
            return self._get_initial_condition(self)
        return {
            fieldname: firedrake.Function(fs[0])
            for fieldname, fs in self.function_spaces.items()
        }

    def get_solver(self):
        """
        Get the function mapping a subinterval index to a generator which advances
        the solver on that subinterval.

        Signature for the function to be returned:
        ```
        :arg index: the subinterval index
        :type index: :class:`int`
        :returns: a generator; each call to :func:`next` on it advances the solver
            by one timestep
        ```
        Any ``solver_kwargs`` are passed as extra keyword arguments. Initial
        conditions are not passed explicitly: the entries of
        :attr:`field_functions` are (re)initialised before the generator is first
        advanced.

        :returns: the function for obtaining the solver
        :rtype: see docstring above
        :raises NotImplementedError: if no ``get_solver`` callback was provided
        """
        if self._get_solver is None:
            raise NotImplementedError("'get_solver' needs implementing.")
        return self._get_solver(self)

    def _transfer(self, source, target_space, **kwargs):
        """
        Transfer a field between meshes using the specified transfer method.

        :arg source: the function to be transferred
        :type source: :class:`firedrake.function.Function` or
            :class:`firedrake.cofunction.Cofunction`
        :arg target_space: the function space which we seek to transfer onto, or the
            function or cofunction to use as the target
        :type target_space: :class:`firedrake.functionspaceimpl.FunctionSpace`,
            :class:`firedrake.function.Function` or
            :class:`firedrake.cofunction.Cofunction`
        :returns: the transferred function
        :rtype: :class:`firedrake.function.Function` or
            :class:`firedrake.cofunction.Cofunction`

        Extra keyword arguments are passed to
        :func:`animate.interpolation.transfer`. The ``transfer_kwargs`` given at
        construction take precedence over them.
        """
        # Update kwargs with those specified by the user
        transfer_kwargs = kwargs.copy()
        transfer_kwargs.update(self._transfer_kwargs)
        return transfer(source, target_space, self._transfer_method, **transfer_kwargs)

    def _outputs_consistent(self):
        """
        Check the user-provided callbacks.

        The initial condition must be a dictionary keyed by exactly the fields in
        :attr:`field_functions`, and the solver must return a generator. In DEBUG
        mode, the solver is also advanced once to check that it does not yield
        before updating lagged solutions.
        """
        for method in ["initial_condition", "solver"]:
            if getattr(self, f"_get_{method}") is None:
                continue
            method_map = getattr(self, f"get_{method}")
            if method == "initial_condition":
                method_map = method_map()
            elif method == "solver":
                self._reinitialise_fields(self.get_initial_condition())
                solver_gen = method_map()(0)
                assert hasattr(solver_gen, "__next__"), "solver should yield"
                if logger.level == DEBUG:
                    next(solver_gen)
                    # Only unsteady fields are stored as (current, lagged) pairs, so
                    # the check only makes sense for one of those
                    lagged_pairs = [
                        value
                        for value in self.field_functions.values()
                        if isinstance(value, tuple)
                    ]
                    if lagged_pairs:
                        f, f_ = lagged_pairs[0]
                        if np.array_equal(f.dat.data_ro, f_.dat.data_ro):
                            self.debug(
                                "Current and lagged solutions are equal. Does the"
                                " solver yield before updating lagged solutions?"
                            )
                # The solver does not return a dictionary, so skip the key checks
                break
            assert isinstance(method_map, dict), f"get_{method} should return a dict"
            mesh_seq_fields = set(self.field_functions)
            method_fields = set(method_map.keys())
            diff = mesh_seq_fields.difference(method_fields)
            assert len(diff) == 0, f"missing fields {diff} in get_{method}"
            diff = method_fields.difference(mesh_seq_fields)
            assert len(diff) == 0, f"unexpected fields {diff} in get_{method}"

    def _function_spaces_consistent(self):
        """
        Determine whether the mesh sequence's function spaces are consistent with
        its meshes.

        :returns: ``True`` if the meshes and function spaces are consistent,
            otherwise ``False``
        :rtype: :class:`bool`
        """
        num_meshes = len(self)
        # Length mismatches must be detected before the strict zips below, which
        # would otherwise raise rather than report an inconsistency
        if len(self.time_partition) != num_meshes:
            return False
        if any(
            len(self._fs[fieldname]) != num_meshes for fieldname in self.field_functions
        ):
            return False
        for fieldname in self.field_functions:
            spaces = self._fs[fieldname]
            if isinstance(spaces[0].mesh(), MeshSequenceGeometry):
                meshes_match = all(
                    mesh1 == mesh2
                    for mesh1, fs in zip(self.meshes, spaces, strict=True)
                    for mesh2 in fs.mesh()
                )
            else:
                meshes_match = all(
                    mesh == fs.mesh()
                    for mesh, fs in zip(self.meshes, spaces, strict=True)
                )
            if not meshes_match:
                return False
            reference_element = spaces[0].ufl_element()
            if not all(reference_element == fs.ufl_element() for fs in spaces):
                return False
        return True

    def _update_function_spaces(self):
        """
        Update the function space dictionary associated with the mesh sequence, if
        it has not been built yet or is inconsistent with the current meshes.
        """
        if self._fs is None or not self._function_spaces_consistent():
            # Build all spaces for each mesh once, rather than once per field
            spaces_per_mesh = [self.get_function_spaces(mesh) for mesh in self]
            self._fs = AttrDict(
                {
                    fieldname: [spaces[fieldname] for spaces in spaces_per_mesh]
                    for fieldname in self.field_functions
                }
            )
        assert self._function_spaces_consistent(), (
            "Meshes and function spaces are inconsistent"
        )

    @property
    def function_spaces(self):
        """
        Get the function spaces associated with the mesh sequence.

        :returns: a dictionary whose keys are field names and whose values are the
            corresponding function spaces
        :rtype: :class:`~.AttrDict` with :class:`str` keys and
            :class:`firedrake.functionspaceimpl.FunctionSpace` values
        """
        self._update_function_spaces()
        return self._fs

    @property
    def initial_condition(self):
        """
        Get the initial conditions associated with the first subinterval.

        :returns: a dictionary whose keys are field names and whose values are the
            corresponding initial conditions applied on the first subinterval
        :rtype: :class:`~.AttrDict` with :class:`str` keys and
            :class:`firedrake.function.Function` values
        """
        return AttrDict(self.get_initial_condition())

    @property
    def solver(self):
        """
        See :meth:`~.MeshSeq.get_solver`.
        """
        return self.get_solver()

    def _create_solutions(self):
        """
        Create the :class:`~.FunctionData` instance for holding solution data.
        """
        self._solutions = ForwardSolutionData(self.time_partition, self.function_spaces)

    @property
    def solutions(self):
        """
        :returns: the solution data object, created on first access
        :rtype: :class:`~.FunctionData`
        """
        if not hasattr(self, "_solutions"):
            self._create_solutions()
        return self._solutions

    def _reinitialise_fields(self, initial_conditions):
        """
        Reinitialise fields and assign initial conditions on the given subinterval.

        Unsteady fields are stored as a ``(current, lagged)`` pair of Functions, with
        the initial condition assigned to the lagged one. Steady fields are stored as
        a single Function with the initial condition assigned to it.

        :arg initial_conditions: the initial conditions to assign to lagged solutions
        :type initial_conditions: :class:`dict` with :class:`str` keys and
            :class:`firedrake.function.Function` values
        """
        for fieldname in self.field_names:
            ic = initial_conditions[fieldname]
            fs = ic.function_space()
            field = self._get_field_metadata(fieldname)
            if field.unsteady:
                self.field_functions[fieldname] = (
                    firedrake.Function(fs, name=fieldname),
                    firedrake.Function(fs, name=f"{fieldname}_old").assign(ic),
                )
            else:
                self.field_functions[fieldname] = firedrake.Function(fs, name=fieldname)
                self.field_functions[fieldname].assign(ic)

    @PETSc.Log.EventDecorator()
    def _solve_forward(self, update_solutions=True, solver_kwargs=None):
        r"""
        Solve a forward problem on a sequence of subintervals.

        After each subinterval, yields the solution transferred onto the next
        subinterval's function spaces. On the final subinterval no transfer is
        performed, so the checkpoint yielded is the one that initialised it.

        Annotation of the adjoint tape is cleared and paused.

        :kwarg update_solutions: if ``True``, updates the solution data
        :type update_solutions: :class:`bool`
        :kwarg solver_kwargs: parameters for the forward solver
        :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
            values may take various types
        :yields: the checkpoint dictionaries described above
        :ytype: :class:`~.AttrDict`
        """
        solver_kwargs = solver_kwargs or {}
        num_subintervals = len(self)
        tp = self.time_partition

        if update_solutions:
            # Reinitialise the solution data object
            self._create_solutions()
            solutions = self.solutions.extract(layout="field")

        # Stop annotating
        if pyadjoint.annotate_tape():
            tape = pyadjoint.get_working_tape()
            if tape is not None:
                tape.clear_tape()
            pyadjoint.pause_annotation()

        # Loop over the subintervals
        checkpoint = self.initial_condition
        for i in range(num_subintervals):
            solver_gen = self.solver(i, **solver_kwargs)

            # Reinitialise fields and assign initial conditions
            self._reinitialise_fields(checkpoint)

            if update_solutions:
                # Solve sequentially between each export time
                for j in range(tp.num_exports_per_subinterval[i] - 1):
                    for _ in range(tp.num_timesteps_per_export[i]):
                        next(solver_gen)
                    # Update the solution data
                    for fieldname, sol in self.field_functions.items():
                        field = self._get_field_metadata(fieldname)
                        if field.unsteady:
                            assert isinstance(sol, tuple)
                            solutions[fieldname].forward[i][j].assign(sol[0])
                            solutions[fieldname].forward_old[i][j].assign(sol[1])
                        else:
                            assert isinstance(sol, firedrake.Function)
                            solutions[fieldname].forward[i][j].assign(sol)
            else:
                # Solve over the entire subinterval in one go
                for _ in range(tp.num_timesteps_per_subinterval[i]):
                    next(solver_gen)

            # Transfer the checkpoint to the next subinterval. The property is used
            # (rather than self._fs) so that the spaces are built and up to date.
            if i < num_subintervals - 1:
                next_checkpoint = {}
                for fieldname, fs in self.function_spaces.items():
                    sol = self.field_functions[fieldname]
                    if self._get_field_metadata(fieldname).unsteady:
                        sol = sol[0]  # the current (not lagged) value
                    next_checkpoint[fieldname] = self._transfer(sol, fs[i + 1])
                checkpoint = AttrDict(next_checkpoint)

            yield checkpoint

    @PETSc.Log.EventDecorator()
    def get_checkpoints(self, run_final_subinterval=False, solver_kwargs=None):
        r"""
        Get checkpoints corresponding to the starting fields on each subinterval.

        :kwarg run_final_subinterval: if ``True``, the solver is run on the final
            subinterval
        :type run_final_subinterval: :class:`bool`
        :kwarg solver_kwargs: parameters for the forward solver
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :returns: checkpoints for each subinterval, each mapping field names to
            Functions
        :rtype: :class:`list` of :class:`~.AttrDict`\s
        """
        solver_kwargs = solver_kwargs or {}
        N = len(self)

        # The first checkpoint is the initial condition
        checkpoints = [self.initial_condition]

        # If there is only one subinterval then we are done
        if N == 1 and not run_final_subinterval:
            return checkpoints

        # Otherwise, solve each subsequent subinterval and append the checkpoint
        solver_gen = self._solve_forward(
            update_solutions=False, solver_kwargs=solver_kwargs
        )
        for _ in range(N if run_final_subinterval else N - 1):
            checkpoints.append(next(solver_gen))

        return checkpoints

    @PETSc.Log.EventDecorator()
    def solve_forward(self, solver_kwargs=None):
        r"""
        Solve a forward problem on a sequence of subintervals.

        A dictionary of solution fields is computed - see
        :class:`~.ForwardSolutionData` for more details.

        :kwarg solver_kwargs: parameters for the forward solver
        :type solver_kwargs: :class:`dict` whose keys are :class:`str`\s and whose
            values may take various types
        :returns: the solution data of the forward solves
        :rtype: :class:`~.ForwardSolutionData`
        """
        solver_kwargs = solver_kwargs or {}
        # Pass the solver kwargs through as a single dict; _solve_forward does not
        # accept them as individual keyword arguments
        solver_gen = self._solve_forward(
            update_solutions=True, solver_kwargs=solver_kwargs
        )
        for _ in range(len(self)):
            next(solver_gen)

        return self.solutions

    def check_element_count_convergence(self):
        r"""
        Check for convergence of the fixed point iteration due to the relative
        difference in element count being smaller than the specified tolerance.

        When ``drop_out_converged`` is set, :attr:`converged` is updated in place
        and only a leading run of subintervals may be marked as converged.

        :return: an array, whose entries are ``True`` if convergence is detected on
            the corresponding subinterval
        :rtype: :class:`numpy.ndarray` of :class:`bool`\s
        """
        if self.params.drop_out_converged:
            # Deliberately aliases self.converged so earlier results are retained
            converged = self.converged
        else:
            converged = np.zeros(len(self), dtype=bool)
        if len(self.element_counts) >= max(2, self.params.miniter + 1):
            for i, (ne_, ne) in enumerate(zip(*self.element_counts[-2:], strict=True)):
                if not self.check_convergence[i]:
                    self.info(
                        f"Skipping element count convergence check on subinterval {i}"
                        f" because check_convergence[{i}] == False."
                    )
                    continue
                if abs(ne - ne_) <= self.params.element_rtol * ne_:
                    converged[i] = True
                    if len(self) == 1:
                        pyrint(
                            f"Element count converged after {self.fp_iteration + 1}"
                            " iterations under relative tolerance"
                            f" {self.params.element_rtol}."
                        )
                    else:
                        pyrint(
                            f"Element count converged on subinterval {i} after"
                            f" {self.fp_iteration + 1} iterations under relative"
                            f" tolerance {self.params.element_rtol}."
                        )

        # Check only early subintervals are marked as converged
        if self.params.drop_out_converged and not converged.all():
            # argmin returns the *first* False entry; argsort's default quicksort is
            # not stable and so does not guarantee this
            first_not_converged = int(np.argmin(converged))
            converged[first_not_converged:] = False

        return converged

    @PETSc.Log.EventDecorator()
    def fixed_point_iteration(
        self,
        adaptor,
        parameters=None,
        update_params=None,
        solver_kwargs=None,
        adaptor_kwargs=None,
    ):
        r"""
        Apply mesh adaptation using a fixed point iteration loop approach.

        :arg adaptor: function for adapting the mesh sequence. Its arguments are the
            mesh sequence and the solution data object. It should return ``True`` if
            the convergence criteria checks are to be skipped for this iteration.
            Otherwise, it should return ``False``.
        :kwarg parameters: parameters to apply to the mesh adaptation process
        :type parameters: :class:`~.AdaptParameters`
        :kwarg update_params: function for updating :attr:`~.MeshSeq.params` at
            each iteration. Its arguments are the parameter class and the fixed
            point iteration
        :kwarg solver_kwargs: parameters to pass to the solver
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg adaptor_kwargs: parameters to pass to the adaptor
        :type adaptor_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :returns: solution data object
        :rtype: :class:`~.ForwardSolutionData`
        """
        # TODO #124: adaptor no longer needs solution data to be passed explicitly
        self.params = parameters or AdaptParameters()
        solver_kwargs = solver_kwargs or {}
        adaptor_kwargs = adaptor_kwargs or {}

        self._reset_counts()
        self.converged[:] = False
        self.check_convergence[:] = True

        for fp_iteration in range(self.params.maxiter):
            self.fp_iteration = fp_iteration
            if update_params is not None:
                update_params(self.params, self.fp_iteration)

            # Solve the forward problem over all meshes
            self.solve_forward(solver_kwargs=solver_kwargs)

            # Adapt meshes, logging element and vertex counts
            continue_unconditionally = adaptor(self, self.solutions, **adaptor_kwargs)
            if self.params.drop_out_converged:
                # Stop checking subintervals which have converged or which the
                # adaptor asked to skip
                self.check_convergence[:] = np.logical_not(
                    np.logical_or(continue_unconditionally, self.converged)
                )
            self.element_counts.append(self.count_elements())
            self.vertex_counts.append(self.count_vertices())

            # Check for element count convergence
            self.converged[:] = self.check_element_count_convergence()
            if self.converged.all():
                break
        else:
            for i, conv in enumerate(self.converged):
                if not conv:
                    pyrint(
                        f"Failed to converge on subinterval {i} in"
                        f" {self.params.maxiter} iterations."
                    )

        return self.solutions