"""
Driver functions for plotting solution data.
"""
from firedrake.pyplot import tricontourf, triplot  # noqa
import matplotlib.pyplot as plt
__all__ = ["plot_snapshots", "plot_indicator_snapshots"]
def _plot_snapshot_grid(snapshots, time_partition, **kwargs):
    """
    Plot a grid of snapshots, with one column per subinterval and one row
    per export, shared by :func:`plot_snapshots` and
    :func:`plot_indicator_snapshots`.

    Any keyword arguments other than ``figsize`` are passed to
    :func:`firedrake.pyplot.tricontourf`.

    :arg snapshots: sequence indexed by subinterval, each entry being a
        sequence of the fields to plot for that subinterval's exports
    :arg time_partition: the :class:`~.TimePartition` object used to
        solve the problem
    :kwarg figsize: figure size, defaulting to ``(6 * cols, 24 // cols)``
    :returns: the figure, its axes (squeezed in the same way as
        :func:`matplotlib.pyplot.subplots` does by default) and a list of
        lists of contour sets, indexed by subinterval then export
    """
    tp = time_partition
    rows = tp.num_exports_per_subinterval[0] - 1
    cols = tp.num_subintervals
    steady = rows == cols == 1
    figsize = kwargs.pop("figsize", (6 * cols, 24 // cols))
    # Always request a 2D array of axes so that ``grid[j, i]`` is valid for
    # every layout, including a single row with several columns, where the
    # default (squeezed) array would be 1D and 2-index access would fail
    fig, grid = plt.subplots(
        rows, cols, sharex="col", figsize=figsize, squeeze=False
    )
    tcs = []
    for i, snapshots_i in enumerate(snapshots):
        grid[0, i].set_title(f"Mesh[{i}]")
        tc = []
        for j, snapshot in enumerate(snapshots_i):
            ax = grid[j, i]
            tc.append(tricontourf(snapshot, axes=ax, **kwargs))
            if not steady:
                # Export j is the (j + 1)-th export after the subinterval start
                time = (
                    tp.subintervals[i][0]
                    + (j + 1) * tp.timesteps[i] * tp.num_timesteps_per_export[i]
                )
                # Position the label relative to the axes rather than the
                # data, so it stays in the corner whatever the mesh extent
                ax.annotate(
                    f"t={time:.2f}",
                    (0.05, 0.05),
                    xycoords="axes fraction",
                    color="white",
                )
        tcs.append(tc)
    plt.tight_layout()
    # Hand back axes in the same form as ``plt.subplots(..., squeeze=True)``
    # would have, so callers indexing the returned axes are unaffected
    axes = grid.item() if grid.size == 1 else grid.squeeze()
    return fig, axes, tcs


def plot_snapshots(solutions, time_partition, field, label, **kwargs):
    """
    Plot a sequence of snapshots associated with
    ``solutions.field.label`` and a given
    :class:`~.TimePartition`.

    Any keyword arguments are passed to
    :func:`firedrake.pyplot.tricontourf`, except for ``figsize``,
    which sets the size of the figure.

    :arg solutions: :class:`~.AttrDict` of solutions
        computed by solving a forward or adjoint
        problem
    :arg time_partition: the :class:`~.TimePartition`
        object used to solve the problem
    :arg field: solution field of choice
    :arg label: choose from ``'forward'``, ``'forward_old'``,
        ``'adjoint'`` and ``'adjoint_next'``
    :returns: the figure, its axes and a list of lists of contour
        sets, indexed by subinterval then export
    """
    return _plot_snapshot_grid(solutions[field][label], time_partition, **kwargs)


def plot_indicator_snapshots(indicators, time_partition, field, **kwargs):
    """
    Plot a sequence of snapshots associated with
    ``indicators`` and a given :class:`~.TimePartition`.

    Any keyword arguments are passed to
    :func:`firedrake.pyplot.tricontourf`, except for ``figsize``,
    which sets the size of the figure.

    :arg indicators: mapping keyed by field, whose values are lists of
        lists of indicators, indexed by mesh sequence index, then timestep
    :arg time_partition: the :class:`~.TimePartition`
        object used to solve the problem
    :arg field: field whose indicators are to be plotted
    :returns: the figure, its axes and a list of lists of contour
        sets, indexed by subinterval then export
    """
    return _plot_snapshot_grid(indicators[field], time_partition, **kwargs)