"""
Drivers for goal-oriented error estimation on sequences of meshes.
"""
from collections.abc import Iterable
from copy import deepcopy
import numpy as np
import ufl
from animate.interpolation import interpolate
from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager
from firedrake.petsc import PETSc
from .adjoint import AdjointMeshSeq
from .error_estimation import get_dwr_indicator
from .function_data import IndicatorData
from .log import pyrint
from .options import GoalOrientedAdaptParameters
__all__ = ["GoalOrientedMeshSeq"]
class GoalOrientedMeshSeq(AdjointMeshSeq):
    """
    An extension of :class:`~.AdjointMeshSeq` to account for goal-oriented problems.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the mesh sequence.

        All positional and keyword arguments are forwarded unchanged to the parent
        class constructor.
        """
        super().__init__(*args, **kwargs)
        # Error estimator values; appended to in :meth:`fixed_point_iteration`
        self.estimator_values = []
        # Variational forms keyed by field name; ``None`` until they are read in
        self._forms = None
        # Bookkeeping used by :meth:`_detect_changing_coefficients`
        self._prev_form_coeffs = None
        self._changed_form_coeffs = None

    @property
    def forms(self):
        """
        Get the variational form associated with each prognostic field.

        :returns: dictionary where the keys are the field names and the values are the
            UFL forms
        :rtype: :class:`dict`
        :raises AttributeError: if the forms have not been read in yet
        """
        if self._forms is None:
            raise AttributeError(
                "Forms have not been read in."
                " Use read_forms({'field_name': F}) in get_solver to read in the forms."
            )
        return self._forms

    @PETSc.Log.EventDecorator()
    def _detect_changing_coefficients(self, export_idx):
        """
        Detect whether coefficients other than the solution in the variational forms
        change over time. If they do, store the changed coefficients so we can update
        them in :meth:`~.GoalOrientedMeshSeq.indicate_errors`.

        Changed coefficients are stored in a dictionary with the following structure:
        ``{field: {coeff_idx: {export_timestep_idx: coefficient}}}``, where
        ``coefficient=forms[field].coefficients()[coeff_idx]`` at export timestep
        ``export_timestep_idx``.

        Coefficients are compared against their values at the previous export
        timestep using :func:`numpy.allclose` on their underlying data arrays.

        :arg export_idx: index of the current export timestep within the subinterval
        :type export_idx: :class:`int`
        """
        if export_idx == 0:
            # Copy coefficients at subinterval's first export timestep
            self._prev_form_coeffs = {
                field: deepcopy(form.coefficients())
                for field, form in self.forms.items()
            }
            self._changed_form_coeffs = {field: {} for field in self.fields}
        else:
            # Store coefficients that have changed since the previous export timestep
            for field in self.fields:
                # Coefficients at the current timestep
                coeffs = self.forms[field].coefficients()
                for coeff_idx, (coeff, prev_coeff) in enumerate(
                    zip(coeffs, self._prev_form_coeffs[field])
                ):
                    # Skip solution fields since they are stored separately
                    if coeff.name().split("_old")[0] in self.time_partition.field_names:
                        continue
                    if not np.allclose(
                        coeff.vector().array(), prev_coeff.vector().array()
                    ):
                        # On the first detected change, also record the value held
                        # up to now under export index 0
                        if coeff_idx not in self._changed_form_coeffs[field]:
                            self._changed_form_coeffs[field][coeff_idx] = {
                                0: deepcopy(prev_coeff)
                            }
                        self._changed_form_coeffs[field][coeff_idx][export_idx] = (
                            deepcopy(coeff)
                        )
                        # Use the current coeff for comparison in the next timestep
                        prev_coeff.assign(coeff)

    @PETSc.Log.EventDecorator()
    def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1):
        """
        Construct a sequence of globally enriched spaces.

        The following global enrichment methods are supported:

        * h-refinement (``enrichment_method='h'``) - refine each mesh element
          uniformly in each direction;
        * p-refinement (``enrichment_method='p'``) - increase the function space
          polynomial order by one globally.

        :kwarg enrichment_method: the method for enriching the mesh sequence
        :type enrichment_method: :class:`str`
        :kwarg num_enrichments: the number of enrichments to apply
        :type num_enrichments: :class:`int`
        :returns: the enriched mesh sequence
        :rtype: the type is inherited from the parent mesh sequence
        :raises ValueError: if the enrichment method is not ``'h'`` or ``'p'``, if
            ``num_enrichments`` is not positive, or if h-enrichment is requested
            while a later mesh compares equal to the first mesh in the sequence
        """
        if enrichment_method not in ("h", "p"):
            raise ValueError(f"Enrichment method '{enrichment_method}' not supported.")
        if num_enrichments <= 0:
            raise ValueError("A positive number of enrichments is required.")

        # Apply h-refinement
        if enrichment_method == "h":
            if any(mesh == self.meshes[0] for mesh in self.meshes[1:]):
                raise ValueError(
                    "h-enrichment is not supported for shallow-copied meshes."
                )
            # Take the finest level of a hierarchy built on each base mesh
            meshes = [MeshHierarchy(mesh, num_enrichments)[-1] for mesh in self.meshes]
        else:
            meshes = self.meshes

        # Construct object to hold enriched spaces
        enriched_mesh_seq = type(self)(
            self.time_partition,
            meshes,
            get_function_spaces=self._get_function_spaces,
            get_initial_condition=self._get_initial_condition,
            get_solver=self._get_solver,
            get_qoi=self._get_qoi,
            qoi_type=self.qoi_type,
        )
        enriched_mesh_seq._update_function_spaces()

        # Apply p-refinement
        if enrichment_method == "p":
            for label, fs in enriched_mesh_seq.function_spaces.items():
                for n, _space in enumerate(fs):
                    # Rebuild each space with the element degree raised
                    element = _space.ufl_element()
                    element = element.reconstruct(
                        degree=element.degree() + num_enrichments
                    )
                    enriched_mesh_seq._fs[label][n] = FunctionSpace(
                        enriched_mesh_seq.meshes[n], element
                    )

        return enriched_mesh_seq

    @staticmethod
    def _get_transfer_function(enrichment_method):
        """
        Get the function for transferring function data between a mesh sequence and its
        enriched counterpart.

        :arg enrichment_method: the enrichment method used to generate the counterpart
            - see :meth:`~.GoalOrientedMeshSeq.get_enriched_mesh_seq` for the supported
            enrichment methods
        :type enrichment_method: :class:`str`
        :returns: the function for mapping function data between mesh sequences; the
            ``prolong`` method of a new :class:`firedrake.TransferManager` for
            ``'h'`` and :func:`animate.interpolation.interpolate` otherwise
        """
        if enrichment_method == "h":
            return TransferManager().prolong
        else:
            return interpolate

    def _create_indicators(self):
        """
        Create the :class:`~.FunctionData` instance for holding error indicator data.
        """
        self._indicators = IndicatorData(self.time_partition, self.meshes)

    @property
    def indicators(self):
        """
        The indicator data object is created on first access if it does not exist.

        :returns: the error indicator data object
        :rtype: :class:`~.IndicatorData`
        """
        if not hasattr(self, "_indicators"):
            self._create_indicators()
        return self._indicators

    @PETSc.Log.EventDecorator()
    def indicate_errors(
        self, enrichment_kwargs=None, solver_kwargs=None, indicator_fn=get_dwr_indicator
    ):
        """
        Compute goal-oriented error indicators for each subinterval based on solving the
        adjoint problem in a globally enriched space.

        :kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
            method - see :meth:`~.GoalOrientedMeshSeq.get_enriched_mesh_seq` for the
            supported enrichment methods and options
        :type enrichment_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg solver_kwargs: parameters for the forward solver, as well as any
            parameters for the QoI, which should be included as a sub-dictionary with
            key 'qoi_kwargs'
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
            take various types
        :kwarg indicator_fn: function which maps the form, adjoint error and enriched
            space(s) as arguments to the error indicator
            :class:`firedrake.function.Function`
        :returns: solution and indicator data objects
        :rtype1: :class:`~.AdjointSolutionData`
        :rtype2: :class:`~.IndicatorData`
        """
        solver_kwargs = solver_kwargs or {}
        # User-provided enrichment options take precedence over the defaults
        default_enrichment_kwargs = {"enrichment_method": "p", "num_enrichments": 1}
        enrichment_kwargs = dict(default_enrichment_kwargs, **(enrichment_kwargs or {}))
        enriched_mesh_seq = self.get_enriched_mesh_seq(**enrichment_kwargs)
        transfer = self._get_transfer_function(enrichment_kwargs["enrichment_method"])

        # Reinitialise the error indicator data object
        self._create_indicators()

        # Initialise adjoint solver generators on the MeshSeq and its enriched version
        adj_sol_gen = self._solve_adjoint(**solver_kwargs)
        # Track form coefficient changes in the enriched problem if the problem is
        # unsteady
        adj_sol_gen_enriched = enriched_mesh_seq._solve_adjoint(
            track_coefficients=not self.steady,
            **solver_kwargs,
        )

        # Keys for looking up solution data; in the steady case the lagged/next
        # keys fall back to the current ones
        FWD, ADJ = "forward", "adjoint"
        FWD_OLD = "forward" if self.steady else "forward_old"
        ADJ_NEXT = "adjoint" if self.steady else "adjoint_next"
        P0_spaces = [FunctionSpace(mesh, "DG", 0) for mesh in self]
        # Short alias for the enriched mesh sequence
        emseq = enriched_mesh_seq

        # Loop over each subinterval in reverse
        for i in reversed(range(len(self))):
            # Solve the adjoint problem on the current subinterval
            next(adj_sol_gen)
            next(adj_sol_gen_enriched)

            # Get Functions
            u, u_, u_star, u_star_next, u_star_e = {}, {}, {}, {}, {}
            enriched_spaces = {
                f: enriched_mesh_seq.function_spaces[f][i] for f in self.fields
            }
            for f, fs_e in enriched_spaces.items():
                if self.field_types[f] == "steady":
                    u[f] = enriched_mesh_seq.fields[f]
                else:
                    u[f], u_[f] = enriched_mesh_seq.fields[f]
                u_star[f] = Function(fs_e)
                u_star_next[f] = Function(fs_e)
                u_star_e[f] = Function(fs_e)

            # Loop over each timestep
            for j in range(self.time_partition.num_exports_per_subinterval[i] - 1):
                # In case of having multiple solution fields that are solved for one
                # after another, the field that is solved for first uses the values of
                # latter fields from the previous timestep. Therefore, we must transfer
                # the lagged solution of latter fields as if they were the current
                # timestep solutions. This assumes that the order of fields being solved
                # for in get_solver is the same as their order in self.fields
                for f_next in self.time_partition.field_names[1:]:
                    transfer(self.solutions[f_next][FWD_OLD][i][j], u[f_next])

                # Loop over each strongly coupled field
                for f in self.fields:
                    # Transfer solutions associated with the current field f
                    transfer(self.solutions[f][FWD][i][j], u[f])
                    if self.field_types[f] == "unsteady":
                        transfer(self.solutions[f][FWD_OLD][i][j], u_[f])
                    transfer(self.solutions[f][ADJ][i][j], u_star[f])
                    transfer(self.solutions[f][ADJ_NEXT][i][j], u_star_next[f])

                    # Combine adjoint solutions as appropriate
                    u_star[f].assign(0.5 * (u_star[f] + u_star_next[f]))
                    u_star_e[f].assign(
                        0.5
                        * (
                            enriched_mesh_seq.solutions[f][ADJ][i][j]
                            + enriched_mesh_seq.solutions[f][ADJ_NEXT][i][j]
                        )
                    )
                    # Adjoint error: enriched solution minus transferred base solution
                    u_star_e[f] -= u_star[f]

                    # Update other time-dependent form coefficients if they changed
                    # since the previous export timestep
                    if not self.steady and emseq._changed_form_coeffs[f]:
                        for idx, coeffs in emseq._changed_form_coeffs[f].items():
                            if j in coeffs:
                                emseq.forms[f].coefficients()[idx].assign(coeffs[j])

                    # Evaluate error indicator
                    indi_e = indicator_fn(enriched_mesh_seq.forms[f], u_star_e[f])

                    # Transfer back to the base space
                    indi = self._transfer(indi_e, P0_spaces[i])
                    # Take the modulus and bound the indicator below by 1.0e-16
                    indi.interpolate(abs(indi))
                    self.indicators[f][i][j].interpolate(ufl.max_value(indi, 1.0e-16))

        return self.solutions, self.indicators

    @PETSc.Log.EventDecorator()
    def error_estimate(self, absolute_value=False):
        r"""
        Deduce the error estimator value associated with error indicator fields defined
        over the mesh sequence.

        The estimator is the sum over all fields, subintervals and exported
        indicators of the subinterval timestep multiplied by the sum of the
        indicator's entries.

        :kwarg absolute_value: if ``True``, the modulus is taken on each element. Note
            that in this case the stored indicators are overwritten in place with
            their absolute values.
        :type absolute_value: :class:`bool`
        :returns: the error estimator value
        :rtype: :class:`float`
        :raises TypeError: if ``absolute_value`` is not a :class:`bool`
        :raises ValueError: if an indicator field is not a field of the time partition
        """
        assert isinstance(self.indicators, IndicatorData)
        if not isinstance(absolute_value, bool):
            raise TypeError(
                f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'."
            )
        estimator = 0
        for field, by_field in self.indicators.items():
            if field not in self.time_partition.field_names:
                raise ValueError(
                    f"Key '{field}' does not exist in the TimePartition provided."
                )
            assert not isinstance(by_field, Function) and isinstance(by_field, Iterable)
            # One timestep value is paired with each subinterval's indicators
            for by_mesh, dt in zip(by_field, self.time_partition.timesteps):
                assert not isinstance(by_mesh, Function) and isinstance(
                    by_mesh, Iterable
                )
                for indicator in by_mesh:
                    if absolute_value:
                        indicator.interpolate(abs(indicator))
                    estimator += dt * indicator.vector().gather().sum()
        return estimator

    def check_estimator_convergence(self):
        """
        Check for convergence of the fixed point iteration due to the relative
        difference in error estimator value being smaller than the specified tolerance.

        The check is only made once at least ``max(2, miniter + 1)`` estimator values
        have been recorded, and is skipped entirely if no subinterval is flagged in
        ``check_convergence``.

        :returns: ``True`` if estimator convergence is detected, else ``False``
        :rtype: :class:`bool`
        """
        if not self.check_convergence.any():
            self.info(
                "Skipping estimator convergence check because check_convergence"
                f" contains False values for indices {self._subintervals_not_checked}."
            )
            return False
        if len(self.estimator_values) >= max(2, self.params.miniter + 1):
            # Compare the two most recent estimator values
            ee_, ee = self.estimator_values[-2:]
            if abs(ee - ee_) < self.params.estimator_rtol * abs(ee_):
                pyrint(
                    f"Error estimator converged after {self.fp_iteration+1} iterations"
                    f" under relative tolerance {self.params.estimator_rtol}."
                )
                return True
        return False

    @PETSc.Log.EventDecorator()
    def fixed_point_iteration(
        self,
        adaptor,
        parameters=None,
        update_params=None,
        enrichment_kwargs=None,
        adaptor_kwargs=None,
        solver_kwargs=None,
        indicator_fn=get_dwr_indicator,
    ):
        r"""
        Apply goal-oriented mesh adaptation using a fixed point iteration loop approach.

        :arg adaptor: function for adapting the mesh sequence. Its arguments are the
            mesh sequence and the solution and indicator data objects. It should return
            ``True`` if the convergence criteria checks are to be skipped for this
            iteration. Otherwise, it should return ``False``.
        :kwarg parameters: parameters to apply to the mesh adaptation process
        :type parameters: :class:`~.GoalOrientedAdaptParameters`
        :kwarg update_params: function for updating :attr:`~.MeshSeq.params` at each
            iteration. Its arguments are the parameter class and the fixed point
            iteration
        :kwarg enrichment_kwargs: keyword arguments to pass to the global enrichment
            method
        :type enrichment_kwargs: :class:`dict` with :class:`str` keys and values which
            may take various types
        :kwarg solver_kwargs: parameters to pass to the solver
        :type solver_kwargs: :class:`dict` with :class:`str` keys and values which may
            take various types
        :kwarg adaptor_kwargs: parameters to pass to the adaptor
        :type adaptor_kwargs: :class:`dict` with :class:`str` keys and values which may
            take various types
        :kwarg indicator_fn: function which maps the form, adjoint error and enriched
            space(s) as arguments to the error indicator
            :class:`firedrake.function.Function`
        :returns: solution and indicator data objects
        :rtype1: :class:`~.AdjointSolutionData`
        :rtype2: :class:`~.IndicatorData`
        """
        # TODO #124: adaptor no longer needs solution and indicator data to be passed
        #            explicitly
        self.params = parameters or GoalOrientedAdaptParameters()
        enrichment_kwargs = enrichment_kwargs or {}
        adaptor_kwargs = adaptor_kwargs or {}
        solver_kwargs = solver_kwargs or {}

        # Reset the bookkeeping from any previous run
        self._reset_counts()
        self.qoi_values = []
        self.estimator_values = []
        self.converged[:] = False
        self.check_convergence[:] = True

        for fp_iteration in range(self.params.maxiter):
            self.fp_iteration = fp_iteration
            if update_params is not None:
                update_params(self.params, self.fp_iteration)

            # Indicate errors over all meshes
            self.indicate_errors(
                enrichment_kwargs=enrichment_kwargs,
                solver_kwargs=solver_kwargs,
                indicator_fn=indicator_fn,
            )

            # Check for QoI convergence
            # TODO #23: Put this check inside the adjoint solve as an optional return
            #           condition so that we can avoid unnecessary extra solves
            self.qoi_values.append(self.J)
            qoi_converged = self.check_qoi_convergence()
            if self.params.convergence_criteria == "any" and qoi_converged:
                self.converged[:] = True
                break

            # Check for error estimator convergence
            self.estimator_values.append(self.error_estimate())
            ee_converged = self.check_estimator_convergence()
            if self.params.convergence_criteria == "any" and ee_converged:
                self.converged[:] = True
                break

            # Adapt meshes and log element counts
            continue_unconditionally = adaptor(
                self, self.solutions, self.indicators, **adaptor_kwargs
            )
            if self.params.drop_out_converged:
                # Stop checking subintervals which the adaptor asked to skip or
                # which have already converged
                self.check_convergence[:] = np.logical_not(
                    np.logical_or(continue_unconditionally, self.converged)
                )
            self.element_counts.append(self.count_elements())
            self.vertex_counts.append(self.count_vertices())

            # Check for element count convergence
            self.converged[:] = self.check_element_count_convergence()
            elem_converged = self.converged.all()
            if self.params.convergence_criteria == "any" and elem_converged:
                break

            # Convergence check for 'all' mode
            if qoi_converged and ee_converged and elem_converged:
                break
        else:
            # Only reached if the loop ran to maxiter without a break
            if self.params.convergence_criteria == "all":
                pyrint(f"Failed to converge in {self.params.maxiter} iterations.")
                self.converged[:] = False
            else:
                for i, conv in enumerate(self.converged):
                    if not conv:
                        pyrint(
                            f"Failed to converge on subinterval {i} in"
                            f" {self.params.maxiter} iterations."
                        )

        return self.solutions, self.indicators