"""
Partitioning for the temporal domain.
"""
from collections.abc import Iterable
import numpy as np
from .log import debug
__all__ = ["TimePartition", "TimeInterval", "TimeInstant"]
[docs]
class TimePartition:
    """
    A partition of the time interval of interest into subintervals.
    The subintervals are assumed to be uniform in length. However, different timestep
    values may be used on each subinterval.
    """
    def __init__(
        self,
        end_time,
        num_subintervals,
        timesteps,
        field_names,
        num_timesteps_per_export=1,
        start_time=0.0,
        subintervals=None,
        field_types=None,
    ):
        r"""
        :arg end_time: end time of the interval of interest
        :type end_time: :class:`float` or :class:`int`
        :arg num_subintervals: number of subintervals in the partition
        :type num_subintervals: :class:`int`
        :arg timesteps: a list timesteps to be used on each subinterval, or a single
            timestep to use for all subintervals
        :type timesteps: :class:`list` of :class:`float`\s or :class:`float`
        :arg field_names: the list of field names to consider
        :type field_names: :class:`list` of :class:`str`\s or :class:`str`
        :kwarg num_timesteps_per_export: a list of numbers of timesteps per export for
            each subinterval, or a single number to use for all subintervals
        :type num_timesteps_per_export: :class:`list` of :class`int`\s or :class:`int`
        :kwarg start_time: start time of the interval of interest
        :type start_time: :class:`float` or :class:`int`
        :kwarg subinterals: sequence of subintervals (which need not be of uniform
            length), or ``None`` to use uniform subintervals (the default)
        :type subintervals: :class:`list` of :class:`tuple`\s
        :kwarg field_types: a list of strings indicating whether each field is
            'unsteady' or 'steady', i.e., does the corresponding equation involve time
            derivatives or not?
        :type field_types: :class:`list` of :class:`str`\s or :class:`str`
        """
        debug(100 * "-")
        if isinstance(field_names, str):
            field_names = [field_names]
        self.field_names = field_names
        self.start_time = start_time
        self.end_time = end_time
        self.num_subintervals = int(np.round(num_subintervals))
        if not np.isclose(num_subintervals, self.num_subintervals):
            raise ValueError(
                f"Non-integer number of subintervals '{num_subintervals}'."
            )
        self.debug("num_subintervals")
        self.interval = (self.start_time, self.end_time)
        self.debug("interval")
        # Get subintervals
        self.subintervals = subintervals
        if self.subintervals is None:
            subinterval_time = (self.end_time - self.start_time) / num_subintervals
            self.subintervals = [
                (
                    self.start_time + i * subinterval_time,
                    self.start_time + (i + 1) * subinterval_time,
                )
                for i in range(num_subintervals)
            ]
        self._check_subintervals()
        self.debug("subintervals")
        # Get timestep on each subinterval
        if not isinstance(timesteps, Iterable):
            timesteps = [timesteps] * len(self)
        self.timesteps = timesteps
        self._check_timesteps()
        self.debug("timesteps")
        # Get number of timesteps on each subinterval
        self.num_timesteps_per_subinterval = []
        for i, ((ts, tf), dt) in enumerate(zip(self.subintervals, self.timesteps)):
            num_timesteps = (tf - ts) / dt
            self.num_timesteps_per_subinterval.append(int(np.round(num_timesteps)))
            if not np.isclose(num_timesteps, self.num_timesteps_per_subinterval[-1]):
                raise ValueError(
                    f"Non-integer number of timesteps on subinterval {i}:"
                    f" {num_timesteps}."
                )
        self.debug("num_timesteps_per_subinterval")
        # Get num timesteps per export
        if not isinstance(num_timesteps_per_export, Iterable):
            num_timesteps_per_export = [num_timesteps_per_export] * len(self)
        self.num_timesteps_per_export = num_timesteps_per_export
        self._check_num_timesteps_per_export()
        self.debug("num_timesteps_per_export")
        # Get num exports per subinterval
        self.num_exports_per_subinterval = [
            tsps // tspe + 1
            for tspe, tsps in zip(
                self.num_timesteps_per_export, self.num_timesteps_per_subinterval
            )
        ]
        self.debug("num_exports_per_subinterval")
        self.steady = (
            self.num_subintervals == 1 and self.num_timesteps_per_subinterval[0] == 1
        )
        self.debug("steady")
        # Process field types
        if field_types is None:
            num_fields = len(self.field_names)
            field_types = ["steady" if self.steady else "unsteady"] * num_fields
        elif isinstance(field_types, str):
            field_types = [field_types]
        self.field_types = field_types
        self._check_field_types()
        debug("field_types")
        debug(100 * "-")
[docs]
    def debug(self, attr):
        """
        Print attribute 'msg' for debugging purposes.
        :arg attr: the attribute to display debugging information for
        """
        try:
            val = self.__getattribute__(attr)
        except AttributeError as e:
            raise AttributeError(
                f"Attribute '{attr}' cannot be debugged because it doesn't exist."
            ) from e
        label = " ".join(attr.split("_"))
        debug(f"TimePartition: {label:25s} {val}") 
    def __str__(self):
        return f"{self.subintervals}"
    def __repr__(self):
        timesteps = ", ".join([str(dt) for dt in self.timesteps])
        field_names = ", ".join([f"'{field_name}'" for field_name in self.field_names])
        return (
            f"TimePartition("
            f"end_time={self.end_time}, "
            f"num_subintervals={self.num_subintervals}, "
            f"timesteps=[{timesteps}], "
            f"field_names=[{field_names}])"
        )
    def __len__(self):
        return self.num_subintervals
    def __getitem__(self, index_or_slice):
        """
        :arg index_or_slice: an index or slice to generate a sub-time partition for
        :type index_or_slice: :class:`int` or :class:`slice`
        :returns: a time partition for the given index or slice
        :rtype: :class:`~.TimePartition`
        """
        sl = index_or_slice
        if not isinstance(sl, slice):
            sl = slice(sl, sl + 1, 1)
        step = sl.step or 1
        if step != 1:
            raise NotImplementedError(
                "Can only currently handle slices with step size 1."
            )
        num_subintervals = len(range(sl.start, sl.stop, step))
        return TimePartition(
            end_time=self.subintervals[sl.stop - 1][1],
            num_subintervals=num_subintervals,
            timesteps=self.timesteps[sl],
            field_names=self.field_names,
            num_timesteps_per_export=self.num_timesteps_per_export[sl],
            start_time=self.subintervals[sl.start][0],
            field_types=self.field_types,
        )
    @property
    def num_timesteps(self):
        """
        :returns the total number of timesteps
        :rtype: :class:`int`
        """
        return sum(self.num_timesteps_per_subinterval)
    def _check_subintervals(self):
        if len(self.subintervals) != self.num_subintervals:
            raise ValueError(
                "Number of subintervals provided differs from num_subintervals:"
                f" {len(self.subintervals)} != {self.num_subintervals}."
            )
        if not np.isclose(self.subintervals[0][0], self.start_time):
            raise ValueError(
                "The first subinterval does not start at the start time:"
                f" {self.subintervals[0][0]} != {self.start_time}."
            )
        for i in range(self.num_subintervals - 1):
            if not np.isclose(self.subintervals[i][1], self.subintervals[i + 1][0]):
                raise ValueError(
                    f"The end of subinterval {i} does not match the start of"
                    f" subinterval {i+1}: {self.subintervals[i][1]} !="
                    f" {self.subintervals[i+1][0]}."
                )
        if not np.isclose(self.subintervals[-1][1], self.end_time):
            raise ValueError(
                "The final subinterval does not end at the end time:"
                f" {self.subintervals[-1][1]} != {self.end_time}."
            )
    def _check_timesteps(self):
        if len(self.timesteps) != self.num_subintervals:
            raise ValueError(
                "Number of timesteps does not match num_subintervals:"
                f" {len(self.timesteps)} != {self.num_subintervals}."
            )
    def _check_num_timesteps_per_export(self):
        if len(self.num_timesteps_per_export) != len(
            self.num_timesteps_per_subinterval
        ):
            raise ValueError(
                "Number of timesteps per export and subinterval do not match:"
                f" {len(self.num_timesteps_per_export)}"
                f" != {len(self.num_timesteps_per_subinterval)}."
            )
        for i, (tspe, tsps) in enumerate(
            zip(self.num_timesteps_per_export, self.num_timesteps_per_subinterval)
        ):
            if not isinstance(tspe, int):
                raise TypeError(
                    f"Expected number of timesteps per export on subinterval {i} to be"
                    f" an integer, not '{type(tspe)}'."
                )
            if tsps % tspe != 0:
                raise ValueError(
                    "Number of timesteps per export does not divide number of"
                    f" timesteps per subinterval on subinterval {i}:"
                    f" {tsps} | {tspe} != 0."
                )
    def _check_field_types(self):
        if len(self.field_names) != len(self.field_types):
            raise ValueError(
                "Number of field names does not match number of field types:"
                f" {len(self.field_names)} != {len(self.field_types)}."
            )
        for field_name, field_type in zip(self.field_names, self.field_types):
            if field_type not in ("unsteady", "steady"):
                raise ValueError(
                    f"Expected field type for field '{field_name}' to be either"
                    f" 'unsteady' or 'steady', but got '{field_type}'."
                )
    def __eq__(self, other):
        if len(self) != len(other):
            return False
        return (
            np.allclose(self.subintervals, other.subintervals)
            and np.allclose(self.timesteps, other.timesteps)
            and np.allclose(
                self.num_exports_per_subinterval, other.num_exports_per_subinterval
            )
            and self.field_names == other.field_names
            and self.field_types == other.field_types
        )
    def __ne__(self, other):
        if len(self) != len(other):
            return True
        return (
            not np.allclose(self.subintervals, other.subintervals)
            or not np.allclose(self.timesteps, other.timesteps)
            or not np.allclose(
                self.num_exports_per_subinterval, other.num_exports_per_subinterval
            )
            or not self.field_names == other.field_names
            or not self.field_types == other.field_types
        ) 
[docs]
class TimeInterval(TimePartition):
    """
    A trivial :class:`~.TimePartition` with a single subinterval.
    """
    def __init__(self, *args, **kwargs):
        if isinstance(args[0], tuple):
            assert len(args[0]) == 2
            kwargs["start_time"] = args[0][0]
            end_time = args[0][1]
        else:
            end_time = args[0]
        timestep = args[1]
        field_names = args[2]
        super().__init__(end_time, 1, timestep, field_names, **kwargs)
    def __repr__(self):
        return (
            f"TimeInterval("
            f"end_time={self.end_time}, "
            f"timestep={self.timestep}, "
            f"field_names={self.field_names})"
        )
    @property
    def timestep(self):
        """
        :returns: the timestep used on the single interval
        :rtype: :class:`float`
        """
        return self.timesteps[0] 
[docs]
class TimeInstant(TimeInterval):
    """
    A :class:`~.TimePartition` for steady-state problems.
    Under the hood this means dividing :math:`[0,1)` into a single timestep.
    """
    def __init__(self, field_names, **kwargs):
        if "end_time" in kwargs:
            if "time" in kwargs:
                raise ValueError("Both 'time' and 'end_time' are set.")
            time = kwargs.pop("end_time")
        else:
            time = kwargs.pop("time", 1.0)
        timestep = time
        super().__init__(time, timestep, field_names, **kwargs)
    def __str__(self):
        return f"({self.end_time})"
    def __repr__(self):
        return (
            f"TimeInstant(" f"time={self.end_time}, " f"field_names={self.field_names})"
        )