"""
Driver functions for plotting solution data.
"""
from firedrake.pyplot import tricontourf, triplot # noqa
import matplotlib.pyplot as plt
__all__ = ["plot_snapshots", "plot_indicator_snapshots"]
[docs]
def plot_snapshots(solutions, time_partition, field, label, **kwargs):
"""
Plot a sequence of snapshots associated with
``solutions.field.label`` and a given
:class:`~.TimePartition`.
Any keyword arguments are passed to
:func:`firedrake.plot.tricontourf`.
:arg solutions: :class:`~.AttrDict` of solutions
computed by solving a forward or adjoint
problem
:arg time_partition: the :class:`~.TimePartition`
object used to solve the problem
:arg field: solution field of choice
:arg label: choose from ``'forward'``, ``'forward_old'``
``'adjoint'`` and ``'adjoint_next'``
"""
tp = time_partition
rows = tp.num_exports_per_subinterval[0] - 1
cols = tp.num_subintervals
steady = rows == cols == 1
figsize = kwargs.pop("figsize", (6 * cols, 24 // cols))
fig, axes = plt.subplots(rows, cols, sharex="col", figsize=figsize)
tcs = []
for i, sols_step in enumerate(solutions[field][label]):
ax = axes if steady else axes[0] if cols == 1 else axes[0, i]
ax.set_title(f"Mesh[{i}]")
tc = []
for j, sol in enumerate(sols_step):
ax = axes if steady else axes[j] if cols == 1 else axes[j, i]
tc.append(tricontourf(sol, axes=ax, **kwargs))
if not steady:
time = (
tp.subintervals[i][0]
+ (j + 1) * tp.timesteps[i] * tp.num_timesteps_per_export[i]
)
ax.annotate(f"t={time:.2f}", (0.05, 0.05), color="white")
tcs.append(tc)
plt.tight_layout()
return fig, axes, tcs
[docs]
def plot_indicator_snapshots(indicators, time_partition, field, **kwargs):
"""
Plot a sequence of snapshots associated with
``indicators`` and a given :class:`~.TimePartition`
Any keyword arguments are passed to
:func:`firedrake.plot.tricontourf`.
:arg indicators: list of list of indicators,
indexed by mesh sequence index, then timestep
:arg time_partition: the :class:`~.TimePartition`
object used to solve the problem
"""
tp = time_partition
rows = tp.num_exports_per_subinterval[0] - 1
cols = tp.num_subintervals
steady = rows == cols == 1
figsize = kwargs.pop("figsize", (6 * cols, 24 // cols))
fig, axes = plt.subplots(rows, cols, sharex="col", figsize=figsize)
tcs = []
for i, indi_step in enumerate(indicators[field]):
ax = axes if steady else axes[0] if cols == 1 else axes[0, i]
ax.set_title(f"Mesh[{i}]")
tc = []
for j, indi in enumerate(indi_step):
ax = axes if steady else axes[j] if cols == 1 else axes[j, i]
tc.append(tricontourf(indi, axes=ax, **kwargs))
if not steady:
time = (
tp.subintervals[i][0]
+ (j + 1) * tp.timesteps[i] * tp.num_timesteps_per_export[i]
)
ax.annotate(f"t={time:.2f}", (0.05, 0.05), color="white")
tcs.append(tc)
plt.tight_layout()
return fig, axes, tcs