Source code for movement.mover

from warnings import warn

import firedrake
import firedrake.exceptions as fexc
import numpy as np
from firedrake.cython.dmcommon import create_section
from firedrake.petsc import PETSc

__all__ = ["PrimeMover"]


[docs] class PrimeMover: """ Base class for all mesh movers. """ def __init__( self, mesh, monitor_function=None, raise_convergence_errors=True, **kwargs ): r""" :arg mesh: the physical mesh :type mesh: :class:`firedrake.mesh.MeshGeometry` :arg monitor_function: a Python function which takes a mesh as input :type monitor_function: :class:`~.Callable` :kwarg raise_convergence_errors: convergence error handling behaviour: if `True` then :class:`~.ConvergenceError`\s are raised, else warnings are raised and the program is allowed to continue :kwarg raise_convergence_errors: :class:`bool` """ self.mesh = firedrake.Mesh(mesh.coordinates.copy(deepcopy=True)) self.monitor_function = monitor_function if not raise_convergence_errors: warn( f"{type(self)}.move called with raise_convergence_errors=False." " Beware: this option can produce poor quality meshes!" ) self.raise_convergence_errors = raise_convergence_errors self.dim = self.mesh.topological_dimension() self.gdim = self.mesh.geometric_dimension() self.plex = self.mesh.topology_dm self.vertex_indices = self.plex.getDepthStratum(0) self.edge_indices = self.plex.getDepthStratum(1) # Measures degree = kwargs.get("quadrature_degree") self.dx = firedrake.dx(domain=self.mesh, degree=degree) self.ds = firedrake.ds(domain=self.mesh, degree=degree) self.dS = firedrake.dS(domain=self.mesh, degree=degree) # Mesh coordinate functions self.coord_space = self.mesh.coordinates.function_space() self.x = firedrake.Function(self.mesh.coordinates, name="Physical coordinates") self.xi = firedrake.Function( self.mesh.coordinates, name="Computational coordinates" ) self.v = firedrake.Function(self.coord_space, name="Mesh velocity") def _convergence_message(self, iterations=None): """ Report solver convergence. :kwarg iterations: number of iterations before reaching convergence :type iterations: :class:`int` """ msg = "Solver converged" if iterations: msg += f" in {iterations} iteration{plural(iterations)}" PETSc.Sys.Print(f"{msg}.") def _exception(self, msg, exception=None, error_type=fexc.ConvergenceError): """ Raise an error or warning as indicated by the :attr:`raise_convergence_error` option. :arg msg: message for the error/warning report :type msg: :class:`str` :kwarg exception: original exception that was triggered :type exception: :class:`~.Exception` object :kwarg error_type: error class to use :type error_type: :class:`~.Exception` class """ exc_type = error_type if self.raise_convergence_errors else Warning if exception: raise exc_type(msg) from exception else: raise exc_type(msg) def _convergence_error(self, iterations=None, exception=None): """ Raise an error or warning for a solver fail as indicated by the :attr:`raise_convergence_error` option. :kwarg iterations: number of iterations before failure :type iterations: :class:`int` :kwarg exception: original exception that was triggered :type exception: :class:`~.Exception` """ msg = "Solver failed to converge" if iterations: msg += f" in {iterations} iteration{plural(iterations)}" self._exception(f"{msg}.", exception=exception) def _divergence_error(self, iterations=None, exception=None): """ Raise an error or warning for a solver divergence as indicated by the :attr:`raise_convergence_error` option. :kwarg iterations: number of iterations before failure :type iterations: :class:`int` :kwarg exception: original exception that was triggered :type exception: :class:`~.Exception` """ msg = "Solver diverged" if iterations: msg += f" after {iterations} iteration{plural(iterations)}" self._exception(f"{msg}.", exception=exception) def _get_coordinate_section(self): entity_dofs = np.zeros(self.dim + 1, dtype=np.int32) entity_dofs[0] = self.gdim self._coordinate_section = create_section(self.mesh, entity_dofs)[0] dm_coords = self.plex.getCoordinateDM() dm_coords.setDefaultSection(self._coordinate_section) self._coords_local_vec = dm_coords.createLocalVec() self._update_plex_coordinates() def _update_plex_coordinates(self): if not hasattr(self, "_coords_local_vec"): self._get_coordinate_section() self._coords_local_vec.array[:] = np.reshape( self.mesh.coordinates.dat.data_with_halos, self._coords_local_vec.array.shape, ) self.plex.setCoordinatesLocal(self._coords_local_vec) def _get_edge_vector_section(self): entity_dofs = np.zeros(self.dim + 1, dtype=np.int32) entity_dofs[1] = 1 self._edge_vector_section = create_section(self.mesh, entity_dofs)[0]
[docs] def coordinate_offset(self, index): """ Get the DMPlex coordinate section offset for a given `index`. """ if not hasattr(self, "_coordinate_section"): self._get_coordinate_section() return self._coordinate_section.getOffset(index) // self.dim
[docs] def edge_vector_offset(self, index): """ Get the DMPlex edge vector section offset for a given `index`. """ if not hasattr(self, "_edge_vector_section"): self._get_edge_vector_section() return self._edge_vector_section.getOffset(index)
[docs] def coordinate(self, index): """ Get the mesh coordinate associated with a given `index`. """ return self.mesh.coordinates.dat.data_with_halos[self.get_offset(index)]
[docs] def move(self): """ Move the mesh according to the method of choice. """ raise NotImplementedError("Implement `move` in the derived class.")
[docs] def adapt(self): """ Alias of `move`. """ warn( "`adapt` is deprecated (use `move` instead)", DeprecationWarning, stacklevel=2, ) return self.move()
def plural(iterations): return "s" if iterations != 1 else ""