Source code for animate.checkpointing

import os
from tempfile import mkdtemp

import firedrake
import firedrake.checkpointing as fchk
import firedrake.function as ffunc

from .metric import RiemannianMetric

__all__ = ["get_checkpoint_dir", "load_checkpoint", "save_checkpoint"]


def get_checkpoint_dir():
    """
    Make a temporary directory for checkpointing and return its path.

    The parent directory is taken from the ``ANIMATE_CHECKPOINT_DIR`` environment
    variable when it is set to a non-empty value; otherwise a ``.checkpoints``
    directory alongside this module is used. Only rank 0 of
    ``firedrake.COMM_WORLD`` touches the file system: it creates the parent
    directory if needed, makes a uniquely named subdirectory inside it, and
    broadcasts the resulting path to the other ranks.

    :returns: the path of the newly created temporary directory
    :rtype: :class:`str`
    """
    checkpoint_dir = os.environ.get("ANIMATE_CHECKPOINT_DIR")
    if not checkpoint_dir:
        # Unset or empty: fall back to a directory next to this source file.
        animate_base_dir = os.path.dirname(os.path.realpath(__file__))
        checkpoint_dir = os.path.join(animate_base_dir, ".checkpoints")

    comm = firedrake.COMM_WORLD
    tmpdir = None
    if comm.rank == 0:
        # exist_ok avoids the check-then-create race when several independent
        # jobs share the same checkpoint directory.
        os.makedirs(checkpoint_dir, exist_ok=True)
        tmpdir = mkdtemp(prefix="animate-checkpoint", dir=checkpoint_dir)

    # Every rank takes part in the broadcast; rank 0 supplies the path and the
    # other ranks supply a placeholder.
    tmpdir = comm.bcast(tmpdir, root=0)
    comm.barrier()
    return tmpdir


def load_checkpoint(filepath, mesh_name, metric_name, comm=firedrake.COMM_WORLD):
    """
    Load a metric from a :class:`~.CheckpointFile`.

    The mesh and metric are read from the checkpoint, followed by the stashed
    metric parameters. Any parameter stored as the placeholder string
    ``"Function"`` is replaced by the function saved in the checkpoint under the
    parameter's own name. The result is re-wrapped as a
    :class:`~.RiemannianMetric` with those parameters applied.

    Note that this function only checks that ``filepath`` exists; it does not
    itself require the file to live in Animate's ``.checkpoints`` subdirectory.

    :arg filepath: the path to the checkpoint file
    :type filepath: :class:`str`
    :arg mesh_name: the name under which the mesh is saved in the checkpoint file
    :type mesh_name: :class:`str`
    :arg metric_name: the name under which the metric is saved in the checkpoint
        file
    :type metric_name: :class:`str`
    :kwarg comm: MPI communicator for handling the checkpoint file
    :type comm: :class:`mpi4py.MPI.Intracomm`
    :returns: the metric loaded from the checkpoint
    :rtype: :class:`animate.metric.RiemannianMetric`
    :raises FileNotFoundError: if ``filepath`` does not exist
    """
    if not os.path.exists(filepath):
        # FileNotFoundError subclasses Exception, so existing handlers that
        # catch Exception keep working.
        raise FileNotFoundError(f"Metric file does not exist! Path: {filepath}.")

    with fchk.CheckpointFile(filepath, "r", comm=comm) as chk:
        mesh = chk.load_mesh(mesh_name)
        metric = chk.load_function(mesh, metric_name)

        # Load stashed metric parameters
        mp = chk._read_pickled_dict("metric_parameters", "mp_dict")
        # Iterate over a snapshot because entries are overwritten in the loop.
        for key, value in list(mp.items()):
            # The isinstance guard keeps array-like values from being compared
            # elementwise against the placeholder string.
            if isinstance(value, str) and value == "Function":
                mp[key] = chk.load_function(mesh, key)

    # Re-wrap the plain function as a RiemannianMetric on the same space.
    metric = RiemannianMetric(metric.function_space()).assign(metric)
    metric.set_parameters(mp)
    return metric


def save_checkpoint(filepath, metric, metric_name=None, comm=firedrake.COMM_WORLD):
    """
    Write the metric and underlying mesh to a :class:`~.CheckpointFile`.

    The metric's parameters are stashed in the same file. Variable parameters
    that are functions are saved under the parameter's name and recorded in the
    parameter dictionary as the placeholder string ``"Function"``; constants are
    recorded as floats; anything else is recorded unchanged.

    The file is opened for writing at ``filepath`` exactly as given.

    :arg filepath: the path of the checkpoint file
    :type filepath: :class:`str`
    :arg metric: the metric to save to the checkpoint
    :type metric: :class:`animate.metric.RiemannianMetric`
    :kwarg metric_name: the name under which to save the metric in the checkpoint
        file; defaults to the metric's own name
    :type metric_name: :class:`str`
    :kwarg comm: MPI communicator for handling the checkpoint file
    :type comm: :class:`mpi4py.MPI.Intracomm`
    """
    # Work on a copy so the metric's own parameter dictionary is not modified.
    stashed = metric.metric_parameters.copy()

    with fchk.CheckpointFile(filepath, "w", comm=comm) as checkpoint:
        checkpoint.save_mesh(metric._mesh)
        checkpoint.save_function(metric, name=metric_name or metric.name())

        # Convert each variable parameter into something picklable.
        for param_name, param in metric._variable_parameters.items():
            if isinstance(param, ffunc.Function):
                # Functions are stored in the checkpoint itself; the dictionary
                # keeps only a placeholder.
                checkpoint.save_function(param, name=param_name)
                stashed[param_name] = "Function"
                continue
            is_constant = isinstance(param, firedrake.Constant)
            stashed[param_name] = float(param) if is_constant else param

        checkpoint._write_pickled_dict("metric_parameters", "mp_dict", stashed)