# Source code for goalie.go_mesh_seq

"""
Drivers for goal-oriented error estimation on sequences of meshes.
"""

from collections.abc import Iterable
from copy import deepcopy

import numpy as np
import ufl
from animate.interpolation import interpolate
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager
from firedrake.petsc import PETSc

from .adjoint import AdjointMeshSeq
from .error_estimation import get_dwr_indicator
from .function_data import IndicatorData
from .log import pyrint
from .options import GoalOrientedAdaptParameters

# Public API of this module: only the goal-oriented mesh sequence is exported.
__all__ = ["GoalOrientedMeshSeq"]


class GoalOrientedMeshSeq(AdjointMeshSeq):
    """
    An extension of :class:`~.AdjointMeshSeq` to account for goal-oriented problems.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Error estimator values, one appended per fixed point iteration
        self.estimator_values = []
        # Dictionary of UFL forms keyed by field name; set via :meth:`read_forms`
        self._forms = None
        # Copies of the form coefficients at the previous export timestep, keyed by
        # field; maintained by :meth:`_detect_changing_coefficients`
        self._prev_form_coeffs = None
        # Form coefficients that changed within the current subinterval; see
        # :meth:`_detect_changing_coefficients` for the nested dict structure
        self._changed_form_coeffs = None

    def read_forms(self, forms_dictionary):
        """
        Read in the variational form corresponding to each prognostic field.

        :arg forms_dictionary: dictionary where the keys are the field names and the
            values are the UFL forms
        :type forms_dictionary: :class:`dict`
        :raises ValueError: if a key is not one of the prognostic fields
        :raises TypeError: if a value is not a :class:`ufl.Form`
        """
        for field, form in forms_dictionary.items():
            if field not in self.fields:
                raise ValueError(
                    f"Unexpected field '{field}' in forms dictionary."
                    f" Expected one of {self.time_partition.field_names}."
                )
            if not isinstance(form, ufl.Form):
                raise TypeError(
                    f"Expected a UFL form for field '{field}', not '{type(form)}'."
                )
        self._forms = forms_dictionary

    @property
    def forms(self):
        """
        Get the variational form associated with each prognostic field.

        :returns: dictionary where the keys are the field names and the values are
            the UFL forms
        :rtype: :class:`dict`
        :raises AttributeError: if :meth:`read_forms` has not been called yet
        """
        if self._forms is None:
            raise AttributeError(
                "Forms have not been read in."
                " Use read_forms({'field_name': F}) in get_solver to read in the forms."
            )
        return self._forms

    @PETSc.Log.EventDecorator()
    def _detect_changing_coefficients(self, export_idx):
        """
        Detect whether coefficients other than the solution in the variational forms
        change over time. If they do, store the changed coefficients so we can update
        them in :meth:`~.GoalOrientedMeshSeq.indicate_errors`.

        Changed coefficients are stored in a dictionary with the following structure:
        ``{field: {coeff_idx: {export_timestep_idx: coefficient}}}``, where
        ``coefficient=forms[field].coefficients()[coeff_idx]`` at export timestep
        ``export_timestep_idx``.

        :arg export_idx: index of the current export timestep within the subinterval
        :type export_idx: :class:`int`
        """
        if export_idx == 0:
            # Copy coefficients at subinterval's first export timestep
            self._prev_form_coeffs = {
                field: deepcopy(form.coefficients())
                for field, form in self.forms.items()
            }
            self._changed_form_coeffs = {field: {} for field in self.fields}
        else:
            # Store coefficients that have changed since the previous export timestep
            for field in self.fields:
                # Coefficients at the current timestep
                coeffs = self.forms[field].coefficients()
                for coeff_idx, (coeff, init_coeff) in enumerate(
                    zip(coeffs, self._prev_form_coeffs[field])
                ):
                    # Skip solution fields since they are stored separately
                    if coeff.name().split("_old")[0] in self.time_partition.field_names:
                        continue
                    if not np.allclose(
                        coeff.vector().array(), init_coeff.vector().array()
                    ):
                        if coeff_idx not in self._changed_form_coeffs[field]:
                            # First recorded change: init_coeff has not been updated
                            # since export timestep 0, so it holds that value
                            self._changed_form_coeffs[field][coeff_idx] = {
                                0: deepcopy(init_coeff)
                            }
                        self._changed_form_coeffs[field][coeff_idx][export_idx] = (
                            deepcopy(coeff)
                        )
                        # Use the current coeff for comparison in the next timestep
                        init_coeff.assign(coeff)

    @PETSc.Log.EventDecorator()
    def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1):
        """
        Construct a sequence of globally enriched spaces.

        The following global enrichment methods are supported:
        * h-refinement (``enrichment_method='h'``) - refine each mesh element
          uniformly in each direction;
        * p-refinement (``enrichment_method='p'``) - increase the function space
          polynomial order by one globally.

        :kwarg enrichment_method: the method for enriching the mesh sequence
        :type enrichment_method: :class:`str`
        :kwarg num_enrichments: the number of enrichments to apply
        :type num_enrichments: :class:`int`
        :returns: the enriched mesh sequence
        :rtype: the type is inherited from the parent mesh sequence
        :raises ValueError: for an unsupported enrichment method, a non-positive
            number of enrichments, or h-enrichment of a sequence in which the same
            mesh appears on more than one subinterval
        """
        if enrichment_method not in ("h", "p"):
            raise ValueError(f"Enrichment method '{enrichment_method}' not supported.")
        if num_enrichments <= 0:
            raise ValueError("A positive number of enrichments is required.")

        # Apply h-refinement
        if enrichment_method == "h":
            # A mesh shared between any two subintervals (not just with the first
            # one) would have a MeshHierarchy built from it more than once
            base_meshes = self.meshes
            if any(
                mesh == other
                for k, mesh in enumerate(base_meshes)
                for other in base_meshes[k + 1 :]
            ):
                raise ValueError(
                    "h-enrichment is not supported for shallow-copied meshes."
                )
            meshes = [MeshHierarchy(mesh, num_enrichments)[-1] for mesh in base_meshes]
        else:
            meshes = self.meshes

        # Construct object to hold enriched spaces
        enriched_mesh_seq = type(self)(
            self.time_partition,
            meshes,
            get_function_spaces=self._get_function_spaces,
            get_initial_condition=self._get_initial_condition,
            get_solver=self._get_solver,
            get_qoi=self._get_qoi,
            qoi_type=self.qoi_type,
        )
        enriched_mesh_seq._update_function_spaces()

        # Apply p-refinement
        if enrichment_method == "p":
            for label, fs in enriched_mesh_seq.function_spaces.items():
                for n, _space in enumerate(fs):
                    element = _space.ufl_element()
                    element = element.reconstruct(
                        degree=element.degree() + num_enrichments
                    )
                    enriched_mesh_seq._fs[label][n] = FunctionSpace(
                        enriched_mesh_seq.meshes[n], element
                    )

        return enriched_mesh_seq

    @staticmethod
    def _get_transfer_function(enrichment_method):
        """
        Get the function for transferring function data between a mesh sequence and
        its enriched counterpart.

        :arg enrichment_method: the enrichment method used to generate the counterpart
            - see :meth:`~.GoalOrientedMeshSeq.get_enriched_mesh_seq` for the supported
            enrichment methods
        :type enrichment_method: :class:`str`
        :returns: the function for mapping function data between mesh sequences:
            prolongation via a new :class:`TransferManager` for ``'h'``, otherwise
            :func:`interpolate`
        """
        if enrichment_method == "h":
            return TransferManager().prolong
        else:
            return interpolate

    def _create_indicators(self):
        """
        Create the :class:`~.FunctionData` instance for holding error indicator data.
        """
        self._indicators = IndicatorData(self.time_partition, self.meshes)

    @property
    def indicators(self):
        """
        Get the error indicator data, creating it lazily on first access.

        :returns: the error indicator data object
        :rtype: :class:`~.IndicatorData`
        """
        if not hasattr(self, "_indicators"):
            self._create_indicators()
        return self._indicators

    @PETSc.Log.EventDecorator()
    def indicate_errors(
        self, enrichment_kwargs=None, solver_kwargs=None, indicator_fn=get_dwr_indicator
    ):
        """
        Compute goal-oriented error indicators for each subinterval based on solving
        the adjoint problem in a globally enriched space.

        :kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
            method - see :meth:`~.GoalOrientedMeshSeq.get_enriched_mesh_seq` for the
            supported enrichment methods and options
        :type enrichment_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg solver_kwargs: parameters for the forward solver, as well as any
            parameters for the QoI, which should be included as a sub-dictionary with
            key 'qoi_kwargs'
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
            take various types
        :kwarg indicator_fn: function which maps the form, adjoint error and enriched
            space(s) as arguments to the error indicator
            :class:`firedrake.function.Function`
        :returns: solution and indicator data objects
        :rtype1: :class:`~.AdjointSolutionData`
        :rtype2: :class:`~.IndicatorData`
        """
        solver_kwargs = solver_kwargs or {}
        default_enrichment_kwargs = {"enrichment_method": "p", "num_enrichments": 1}
        enrichment_kwargs = dict(default_enrichment_kwargs, **(enrichment_kwargs or {}))
        enriched_mesh_seq = self.get_enriched_mesh_seq(**enrichment_kwargs)
        transfer = self._get_transfer_function(enrichment_kwargs["enrichment_method"])

        # Reinitialise the error indicator data object
        self._create_indicators()

        # Initialise adjoint solver generators on the MeshSeq and its enriched version
        adj_sol_gen = self._solve_adjoint(**solver_kwargs)
        # Track form coefficient changes in the enriched problem if the problem is
        # unsteady
        adj_sol_gen_enriched = enriched_mesh_seq._solve_adjoint(
            track_coefficients=not self.steady,
            **solver_kwargs,
        )

        # Solution labels; steady problems have no lagged/next-step solutions, so
        # the "old"/"next" labels fall back to the current ones
        FWD, ADJ = "forward", "adjoint"
        FWD_OLD = "forward" if self.steady else "forward_old"
        ADJ_NEXT = "adjoint" if self.steady else "adjoint_next"
        P0_spaces = [FunctionSpace(mesh, "DG", 0) for mesh in self]

        # Loop over each subinterval in reverse
        for i in reversed(range(len(self))):
            # Solve the adjoint problem on the current subinterval
            next(adj_sol_gen)
            next(adj_sol_gen_enriched)

            # Get Functions
            u, u_, u_star, u_star_next, u_star_e = {}, {}, {}, {}, {}
            enriched_spaces = {
                f: enriched_mesh_seq.function_spaces[f][i] for f in self.fields
            }
            for f, fs_e in enriched_spaces.items():
                if self.field_types[f] == "steady":
                    u[f] = enriched_mesh_seq.fields[f]
                else:
                    u[f], u_[f] = enriched_mesh_seq.fields[f]
                u_star[f] = Function(fs_e)
                u_star_next[f] = Function(fs_e)
                u_star_e[f] = Function(fs_e)

            # Loop over each timestep
            for j in range(self.time_partition.num_exports_per_subinterval[i] - 1):
                # In case of having multiple solution fields that are solved for one
                # after another, the field that is solved for first uses the values of
                # latter fields from the previous timestep. Therefore, we must transfer
                # the lagged solution of latter fields as if they were the current
                # timestep solutions. This assumes that the order of fields being
                # solved for in get_solver is the same as their order in self.fields
                for f_next in self.time_partition.field_names[1:]:
                    transfer(self.solutions[f_next][FWD_OLD][i][j], u[f_next])

                # Loop over each strongly coupled field
                for f in self.fields:
                    # Transfer solutions associated with the current field f
                    transfer(self.solutions[f][FWD][i][j], u[f])
                    if self.field_types[f] == "unsteady":
                        transfer(self.solutions[f][FWD_OLD][i][j], u_[f])
                    transfer(self.solutions[f][ADJ][i][j], u_star[f])
                    transfer(self.solutions[f][ADJ_NEXT][i][j], u_star_next[f])

                    # Combine adjoint solutions as appropriate
                    u_star[f].assign(0.5 * (u_star[f] + u_star_next[f]))
                    u_star_e[f].assign(
                        0.5
                        * (
                            enriched_mesh_seq.solutions[f][ADJ][i][j]
                            + enriched_mesh_seq.solutions[f][ADJ_NEXT][i][j]
                        )
                    )
                    # Adjoint error: enriched adjoint minus transferred base adjoint
                    u_star_e[f] -= u_star[f]

                    # Update other time-dependent form coefficients if they changed
                    # since the previous export timestep
                    emseq = enriched_mesh_seq
                    if not self.steady and emseq._changed_form_coeffs[f]:
                        for idx, coeffs in emseq._changed_form_coeffs[f].items():
                            if j in coeffs:
                                emseq.forms[f].coefficients()[idx].assign(coeffs[j])

                    # Evaluate error indicator
                    indi_e = indicator_fn(enriched_mesh_seq.forms[f], u_star_e[f])

                    # Transfer back to the base space, then store the modulus with a
                    # small positive floor
                    indi = self._transfer(indi_e, P0_spaces[i])
                    indi.interpolate(abs(indi))
                    self.indicators[f][i][j].interpolate(ufl.max_value(indi, 1.0e-16))

        return self.solutions, self.indicators

    @PETSc.Log.EventDecorator()
    def error_estimate(self, absolute_value=False):
        r"""
        Deduce the error estimator value associated with error indicator fields
        defined over the mesh sequence.

        The estimator is the sum, over all fields, subintervals and exported
        indicators, of the subinterval timestep times the sum of the indicator's
        degree of freedom values. The stored indicator data are not modified.

        :kwarg absolute_value: if ``True``, the modulus is taken on each element
        :type absolute_value: :class:`bool`
        :returns: the error estimator value
        :rtype: :class:`float`
        :raises TypeError: if ``absolute_value`` is not a :class:`bool`
        :raises ValueError: if the indicators contain a field unknown to the
            :class:`~.TimePartition`
        """
        assert isinstance(self.indicators, IndicatorData)
        if not isinstance(absolute_value, bool):
            raise TypeError(
                f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
            )
        estimator = 0
        for field, by_field in self.indicators.items():
            if field not in self.time_partition.field_names:
                raise ValueError(
                    f"Key '{field}' does not exist in the TimePartition provided."
                )
            assert not isinstance(by_field, Function) and isinstance(by_field, Iterable)
            for by_mesh, dt in zip(by_field, self.time_partition.timesteps):
                assert not isinstance(by_mesh, Function) and isinstance(
                    by_mesh, Iterable
                )
                for indicator in by_mesh:
                    # Global array of the indicator's values (gathered across ranks)
                    values = indicator.vector().gather()
                    if absolute_value:
                        # Take the modulus on a copy rather than interpolating
                        # abs(indicator) in place, which would overwrite the stored
                        # indicators. NOTE(review): equivalent to the modulus per
                        # element for nodal spaces such as DG0.
                        values = np.abs(values)
                    estimator += dt * values.sum()
        return estimator

    def check_estimator_convergence(self):
        """
        Check for convergence of the fixed point iteration due to the relative
        difference in error estimator value being smaller than the specified
        tolerance.

        Convergence is only assessed once at least ``max(2, miniter + 1)`` estimator
        values have been recorded, by comparing the last two of them.

        :returns: ``True`` if estimator convergence is detected, else ``False``
        :rtype: :class:`bool`
        """
        if not self.check_convergence.any():
            self.info(
                "Skipping estimator convergence check because check_convergence"
                f" contains False values for indices {self._subintervals_not_checked}."
            )
            return False
        if len(self.estimator_values) >= max(2, self.params.miniter + 1):
            ee_, ee = self.estimator_values[-2:]
            if abs(ee - ee_) < self.params.estimator_rtol * abs(ee_):
                pyrint(
                    f"Error estimator converged after {self.fp_iteration+1} iterations"
                    f" under relative tolerance {self.params.estimator_rtol}."
                )
                return True
        return False

    @PETSc.Log.EventDecorator()
    def fixed_point_iteration(
        self,
        adaptor,
        parameters=None,
        update_params=None,
        enrichment_kwargs=None,
        adaptor_kwargs=None,
        solver_kwargs=None,
        indicator_fn=get_dwr_indicator,
    ):
        r"""
        Apply goal-oriented mesh adaptation using a fixed point iteration loop
        approach.

        Each iteration indicates errors, then checks QoI convergence, error estimator
        convergence and (after adapting) element count convergence. With
        ``convergence_criteria == "any"`` the loop stops at the first satisfied
        criterion; otherwise it stops when all three are satisfied in one iteration.

        :arg adaptor: function for adapting the mesh sequence. Its arguments are the
            mesh sequence and the solution and indicator data objects. It should
            return ``True`` if the convergence criteria checks are to be skipped for
            this iteration. Otherwise, it should return ``False``.
        :kwarg parameters: parameters to apply to the mesh adaptation process
        :type parameters: :class:`~.GoalOrientedAdaptParameters`
        :kwarg update_params: function for updating :attr:`~.MeshSeq.params` at each
            iteration. Its arguments are the parameter class and the fixed point
            iteration
        :kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
            method
        :type enrichment_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg solver_kwargs: parameters to pass to the solver
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
            take various types
        :kwarg adaptor_kwargs: parameters to pass to the adaptor
        :type adaptor_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg indicator_fn: function which maps the form, adjoint error and enriched
            space(s) as arguments to the error indicator
            :class:`firedrake.function.Function`
        :returns: solution and indicator data objects
        :rtype1: :class:`~.AdjointSolutionData`
        :rtype2: :class:`~.IndicatorData`
        """
        # TODO #124: adaptor no longer needs solution and indicator data to be passed
        # explicitly
        self.params = parameters or GoalOrientedAdaptParameters()
        enrichment_kwargs = enrichment_kwargs or {}
        adaptor_kwargs = adaptor_kwargs or {}
        solver_kwargs = solver_kwargs or {}
        self._reset_counts()
        self.qoi_values = []
        self.estimator_values = []
        # Per-subinterval flags (assigned via slicing, so presumably arrays)
        self.converged[:] = False
        self.check_convergence[:] = True

        for fp_iteration in range(self.params.maxiter):
            self.fp_iteration = fp_iteration
            if update_params is not None:
                update_params(self.params, self.fp_iteration)

            # Indicate errors over all meshes
            self.indicate_errors(
                enrichment_kwargs=enrichment_kwargs,
                solver_kwargs=solver_kwargs,
                indicator_fn=indicator_fn,
            )

            # Check for QoI convergence
            # TODO #23: Put this check inside the adjoint solve as an optional return
            # condition so that we can avoid unnecessary extra solves
            self.qoi_values.append(self.J)
            qoi_converged = self.check_qoi_convergence()
            if self.params.convergence_criteria == "any" and qoi_converged:
                self.converged[:] = True
                break

            # Check for error estimator convergence
            self.estimator_values.append(self.error_estimate())
            ee_converged = self.check_estimator_convergence()
            if self.params.convergence_criteria == "any" and ee_converged:
                self.converged[:] = True
                break

            # Adapt meshes and log element counts
            continue_unconditionally = adaptor(
                self, self.solutions, self.indicators, **adaptor_kwargs
            )
            if self.params.drop_out_converged:
                # Stop checking subintervals that have converged, unless the adaptor
                # asked for checks to be skipped this iteration
                self.check_convergence[:] = np.logical_not(
                    np.logical_or(continue_unconditionally, self.converged)
                )
            self.element_counts.append(self.count_elements())
            self.vertex_counts.append(self.count_vertices())

            # Check for element count convergence
            self.converged[:] = self.check_element_count_convergence()
            elem_converged = self.converged.all()
            if self.params.convergence_criteria == "any" and elem_converged:
                break

            # Convergence check for 'all' mode
            if qoi_converged and ee_converged and elem_converged:
                break
        else:
            # Loop exhausted maxiter without a break
            if self.params.convergence_criteria == "all":
                pyrint(f"Failed to converge in {self.params.maxiter} iterations.")
                self.converged[:] = False
            else:
                for i, conv in enumerate(self.converged):
                    if not conv:
                        pyrint(
                            f"Failed to converge on subinterval {i} in"
                            f" {self.params.maxiter} iterations."
                        )

        return self.solutions, self.indicators