r"""
Nested dictionaries of solution data :class:`~.Function`\s.
"""
from abc import ABC, abstractmethod
import firedrake.function as ffunc
import firedrake.functionspace as ffs
from firedrake import TransferManager
from firedrake.checkpointing import CheckpointFile
from firedrake.output.vtk_output import VTKFile
from .utility import AttrDict
__all__ = [
"ForwardSolutionData",
"AdjointSolutionData",
"IndicatorData",
]
class FunctionData(ABC):
"""
Abstract base class for classes holding field data.
"""
@abstractmethod
def __init__(self, time_partition, function_spaces):
r"""
:arg time_partition: the :class:`~.TimePartition` used to discretise the problem
in time
:arg function_spaces: the dictionary of :class:`~.FunctionSpace`\s used to
discretise the problem in space
"""
self.time_partition = time_partition
self.function_spaces = function_spaces
self._data = None
self.labels = self._label_dict[
"steady" if time_partition.steady else "unsteady"
]
def _create_data(self):
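        r"""
        Lazily build the nested data structure: for each field and each label, a
        doubly-nested list of :class:`~.Function`\s, indexed first by subinterval
        and then by export.
        """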
assert self._label_dict
tp = self.time_partition
self._data = AttrDict(
{
field: AttrDict(
{
label: [
[
ffunc.Function(fs, name=f"{field}_{label}")
for j in range(tp.num_exports_per_subinterval[i] - 1)
]
for i, fs in enumerate(self.function_spaces[field])
]
for label in self.labels
}
)
for field in tp.field_names
}
)
@property
def _data_by_field(self):
"""
Extract field data array in the default layout: as a doubly-nested dictionary
whose first key is the field name and second key is the field label. Entries
of the doubly-nested dictionary are doubly-nested lists, indexed first by
subinterval and then by export.
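
        For example, with a hypothetical field named ``"u"`` carrying a
        ``"forward"`` label, the :class:`~.Function` at export ``j`` of
        subinterval ``i`` is accessed as

        .. code-block:: python

            f = self._data_by_field["u"]["forward"][i][j]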
"""
if self._data is None:
self._create_data()
return self._data
def __getitem__(self, key):
return self._data_by_field[key]
def items(self):
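        """Return (field, label-indexed data) pairs in the default layout."""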
return self._data_by_field.items()
@property
def _data_by_label(self):
"""
Extract field data array in an alternative layout: as a doubly-nested dictionary
whose first key is the field label and second key is the field name. Entries
of the doubly-nested dictionary are doubly-nested lists, which retain the
default layout: indexed first by subinterval and then by export.
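
        With the same hypothetical field ``"u"`` and label ``"forward"`` as above,
        the same :class:`~.Function` is accessed as

        .. code-block:: python

            f = self._data_by_label["forward"]["u"][i][j]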
"""
tp = self.time_partition
return AttrDict(
{
label: AttrDict(
{f: self._data_by_field[f][label] for f in tp.field_names}
)
for label in self.labels
}
)
@property
def _data_by_subinterval(self):
"""
        Extract field data array in an alternative format: as a list indexed by
        subinterval. Entries of the list are doubly-nested dictionaries, which
        retain the default layout: the first key is the field name and the second
        key is the field label. Entries of the doubly-nested dictionaries are
        lists of field data, indexed by export.
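
        Continuing the hypothetical example above, the same :class:`~.Function` is
        accessed as

        .. code-block:: python

            f = self._data_by_subinterval[i]["u"]["forward"][j]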
"""
tp = self.time_partition
return [
AttrDict(
{
field: AttrDict(
{
label: self._data_by_field[field][label][subinterval]
for label in self.labels
}
)
for field in tp.field_names
}
)
for subinterval in range(tp.num_subintervals)
]
def export(self, output_fpath, export_field_types=None, initial_condition=None):
"""
Export field data to a file. The file format is determined by the extension of
the output file path. Supported formats are '.pvd' and '.h5'.
If the output file format is '.pvd', the data is exported as a series of VTK
files using Firedrake's :class:`~.VTKFile`. Since mixed function spaces are not
supported by VTK, each subfunction of a mixed function is exported separately.
        If initial conditions are provided and fields other than 'forward' are to be
        exported, the initial values of those other fields (e.g., 'adjoint' fields)
        are set to 'nan', since they are not defined at the initial time.
If the output file format is '.h5', the data is exported as a single HDF5 file
using Firedrake's :class:`~.CheckpointFile`. If names of meshes in the mesh
sequence are not unique, they are renamed to ``"mesh_i"``, where ``i`` is the
subinterval index. Functions are saved with names of the form ``"field_label"``.
Initial conditions are named in the form ``"field_initial"``. The exported data
may then be loaded using, for example,
.. code-block:: python
with CheckpointFile(output_fpath, "r") as afile:
first_mesh = afile.load_mesh("mesh_0")
initial_condition = afile.load_function(first_mesh, "u_initial")
first_export = afile.load_function(first_mesh, "u_forward", idx=0)
:arg output_fpath: the path to the output file
:type output_fpath: :class:`str`
:kwarg export_field_types: the field types to export; defaults to all available
field types
:type export_field_types: :class:`str` or :class:`list` of :class:`str`
        :kwarg initial_condition: if provided, the initial conditions to export,
            keyed by field name
        :type initial_condition: :class:`dict` of :class:`~.Function`
"""
if export_field_types is None:
default_export_types = {"forward", "adjoint", "error_indicator"}
export_field_types = list(set(self.labels) & default_export_types)
if isinstance(export_field_types, str):
export_field_types = [export_field_types]
if not all(field_type in self.labels for field_type in export_field_types):
raise ValueError(
f"Field types {export_field_types} not recognised."
f" Available types are {self.labels}."
)
if output_fpath.endswith(".pvd"):
self._export_vtk(output_fpath, export_field_types, initial_condition)
elif output_fpath.endswith(".h5"):
self._export_h5(output_fpath, export_field_types, initial_condition)
else:
raise ValueError(
f"Output file format not recognised: '{output_fpath}'."
" Supported formats are '.pvd' and '.h5'."
)
def _export_vtk(self, output_fpath, export_field_types, initial_condition=None):
"""
Export field data to a series of VTK files. Arguments are the same as for
:meth:`~.export`.
"""
tp = self.time_partition
outfile = VTKFile(output_fpath, adaptive=True)
if initial_condition is not None:
ics = []
for field, ic in initial_condition.items():
for field_type in export_field_types:
                    # Work on a deep copy so that renaming and NaN-assignment do
                    # not modify the user's function
                    icc = ic.copy(deepcopy=True)
                    # If the function space is mixed, rename and append each
                    # subfunction separately
                    if hasattr(ic.function_space(), "num_sub_spaces"):
                        for idx, sf in enumerate(icc.subfunctions):
                            if field_type != "forward":
                                sf.assign(float("nan"))
                            sf.rename(f"{field}[{idx}]_{field_type}")
                            ics.append(sf)
                    else:
                        if field_type != "forward":
                            icc.assign(float("nan"))
                        icc.rename(f"{field}_{field_type}")
                        ics.append(icc)
outfile.write(*ics, time=tp.subintervals[0][0])
for i in range(tp.num_subintervals):
for j in range(tp.num_exports_per_subinterval[i] - 1):
time = (
tp.subintervals[i][0]
+ (j + 1) * tp.timesteps[i] * tp.num_timesteps_per_export[i]
)
fs = []
for field in tp.field_names:
mixed = hasattr(self.function_spaces[field][0], "num_sub_spaces")
for field_type in export_field_types:
f = self._data[field][field_type][i][j].copy(deepcopy=True)
if mixed:
for idx, sf in enumerate(f.subfunctions):
sf.rename(f"{field}[{idx}]_{field_type}")
fs.append(sf)
else:
f.rename(f"{field}_{field_type}")
fs.append(f)
outfile.write(*fs, time=time)
def _export_h5(self, output_fpath, export_field_types, initial_condition=None):
"""
Export field data to an HDF5 file. Arguments are the same as for
:meth:`~.export`.
"""
tp = self.time_partition
# Mesh names must be unique
mesh_names = [fs.mesh().name for fs in self.function_spaces[tp.field_names[0]]]
rename_meshes = len(set(mesh_names)) != len(mesh_names)
with CheckpointFile(output_fpath, "w") as outfile:
if initial_condition is not None:
for field, ic in initial_condition.items():
outfile.save_function(ic, name=f"{field}_initial")
for i in range(tp.num_subintervals):
if rename_meshes:
mesh_name = f"mesh_{i}"
mesh = self.function_spaces[tp.field_names[0]][i].mesh()
mesh.name = mesh_name
mesh.topology_dm.name = mesh_name
for field in tp.field_names:
for field_type in export_field_types:
name = f"{field}_{field_type}"
for j in range(tp.num_exports_per_subinterval[i] - 1):
f = self._data[field][field_type][i][j]
outfile.save_function(f, name=name, idx=j)
def transfer(self, target, method="interpolate"):
"""
Transfer all functions from this :class:`~.FunctionData` object to the target
:class:`~.FunctionData` object by interpolation, projection or prolongation.
:arg target: the target :class:`~.FunctionData` object to which to transfer the
data
:type target: :class:`~.FunctionData`
        :kwarg method: the transfer method to use; either 'interpolate', 'project',
            or 'prolong'
        :type method: :class:`str`
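
        For example (``coarse_data`` and ``fine_data`` being assumed names for two
        compatible :class:`~.FunctionData` objects, the latter defined on
        uniformly refined meshes):

        .. code-block:: python

            coarse_data.transfer(fine_data, method="prolong")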
"""
stp = self.time_partition
ttp = target.time_partition
if method not in ["interpolate", "project", "prolong"]:
raise ValueError(
f"Transfer method '{method}' not supported."
" Supported methods are 'interpolate', 'project', and 'prolong'."
)
if stp.num_subintervals != ttp.num_subintervals:
raise ValueError(
"Source and target have different numbers of subintervals."
)
if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval:
raise ValueError(
"Source and target have different numbers of exports per subinterval."
)
common_fields = set(stp.field_names) & set(ttp.field_names)
if not common_fields:
raise ValueError("No common fields between source and target.")
common_labels = set(self.labels) & set(target.labels)
if not common_labels:
raise ValueError("No common labels between source and target.")
        # Lazily create the underlying data structures, if not already present
        if self._data is None:
            self._create_data()
        if target._data is None:
            target._create_data()
        for field in common_fields:
for label in common_labels:
for i in range(stp.num_subintervals):
for j in range(stp.num_exports_per_subinterval[i] - 1):
source_function = self._data[field][label][i][j]
target_function = target._data[field][label][i][j]
if method == "interpolate":
target_function.interpolate(source_function)
elif method == "project":
target_function.project(source_function)
elif method == "prolong":
TransferManager().prolong(source_function, target_function)
class ForwardSolutionData(FunctionData):
"""
Class representing solution data for general forward problems.
For a given exported timestep, the field types are:
* ``'forward'``: the forward solution after taking the timestep;
* ``'forward_old'``: the forward solution before taking the timestep (provided
the problem is not steady-state).
"""
def __init__(self, *args, **kwargs):
self._label_dict = {
"steady": ("forward",),
"unsteady": ("forward", "forward_old"),
}
super().__init__(*args, **kwargs)
class AdjointSolutionData(FunctionData):
"""
Class representing solution data for general adjoint problems.
For a given exported timestep, the field types are:
* ``'forward'``: the forward solution after taking the timestep;
    * ``'forward_old'``: the forward solution before taking the timestep (provided
      the problem is not steady-state);
* ``'adjoint'``: the adjoint solution after taking the timestep;
* ``'adjoint_next'``: the adjoint solution before taking the timestep
backwards (provided the problem is not steady-state).
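
    For example, for a hypothetical field ``"u"``, ``data["u"]["adjoint"][i][j]``
    and ``data["u"]["adjoint_next"][i][j]`` hold the adjoint solutions after and
    before the backward timestep at export ``j`` of subinterval ``i``.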
"""
def __init__(self, *args, **kwargs):
self._label_dict = {
"steady": ("forward", "adjoint"),
"unsteady": ("forward", "forward_old", "adjoint", "adjoint_next"),
}
super().__init__(*args, **kwargs)
class IndicatorData(FunctionData):
"""
Class representing error indicator data.
Note that this class has a single dictionary with the field name as the key, rather
than a doubly-nested dictionary.
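
    For example, ``indicator_data["u"][i][j]`` (with ``indicator_data`` and
    ``"u"`` as assumed names) is the error indicator :class:`~.Function` at
    export ``j`` of subinterval ``i``.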
"""
def __init__(self, time_partition, meshes):
"""
:arg time_partition: the :class:`~.TimePartition` used to discretise the problem
in time
:arg meshes: the list of meshes used to discretise the problem in space
"""
self._label_dict = {
time_dep: ("error_indicator",) for time_dep in ("steady", "unsteady")
}
super().__init__(
time_partition,
{
key: [ffs.FunctionSpace(mesh, "DG", 0) for mesh in meshes]
for key in time_partition.field_names
},
)
@property
def _data_by_field(self):
"""
Extract indicator data array in the default layout: as a dictionary keyed with
the field name. Entries of the dictionary are doubly-nested lists, indexed first
by subinterval and then by export.
"""
if self._data is None:
self._create_data()
return AttrDict(
{
field: self._data[field]["error_indicator"]
for field in self.time_partition.field_names
}
)
@property
def _data_by_label(self):
"""
For indicator data there is only one field label (``"error_indicator"``), so
this method just delegates to :meth:`~._data_by_field`.
"""
return self._data_by_field
@property
def _data_by_subinterval(self):
"""
        Extract indicator data array in an alternative format: as a list indexed by
        subinterval. Entries of the list are dictionaries, keyed by field name.
        Entries of the dictionaries are lists of field data, indexed by export.
"""
tp = self.time_partition
return [
AttrDict({f: self._data_by_field[f][subinterval] for f in tp.field_names})
for subinterval in range(tp.num_subintervals)
]