# Source code for timeserio.model_selection.time_series_split

from typing import List, Optional, Union

import numpy as np
from sklearn.model_selection import TimeSeriesSplit
from sklearn.model_selection._split import _BaseKFold


class PandasTimeSeriesSplit(_BaseKFold):
    """Apply a sklearn TimeSeriesSplit to multiple timeseries in a single DF.

    The dataframe should be ordered by date ascending for each time series,
    and the index should be unique.

    Each group (time series) is split independently with
    ``sklearn.model_selection.TimeSeriesSplit``. The i-th split yielded by
    ``split`` concatenates the i-th train/test folds of every group that still
    has folds left.

    Parameters
    ----------
    groupby : str or list of str
        The column name(s) to group the input dataframe by - each group
        should hold a monotonically increasing time series.
    datetime_col : str
        The column name of the datetime column - used to validate that the
        dataframe is groups of time series.
    n_splits : int, default=3
        Number of splits. Must be at least 2.
    max_train_size : int, optional
        Maximum size for a single training set, applied to each group's
        ``TimeSeriesSplit`` separately.
    """

    def __init__(
        self,
        groupby: Union[str, List[str]],
        datetime_col: str,
        n_splits: int = 3,
        max_train_size: Optional[int] = None,
    ):
        self.groupby = groupby
        self.datetime_col = datetime_col
        self.n_splits = n_splits
        self.max_train_size = max_train_size

    def split(self, df, y=None, groups=None):
        """Generate positional train/test indices for each split.

        Parameters
        ----------
        df : pandas.DataFrame
            Data to split; must pass ``_validate_df``.
        y : array-like, optional
            Targets aligned *positionally* with ``df``; the rows of each
            group are forwarded to that group's ``TimeSeriesSplit.split``.
        groups : ignored
            Present for sklearn API compatibility; grouping is taken from
            the ``groupby`` column(s) of ``df`` instead.

        Yields
        ------
        train : numpy.ndarray
            Positional (``df.iloc``) indices of the training rows.
        test : numpy.ndarray
            Positional (``df.iloc``) indices of the test rows.

        Raises
        ------
        ValueError
            If ``df`` has a non-unique index, is not ascending in
            ``datetime_col`` within each group, or (from
            ``TimeSeriesSplit``) a group is too short for ``n_splits``.
        """
        self._validate_df(df)
        # Maps group key -> positional (iloc) indices of that group's rows.
        group_positions = df.groupby(self.groupby).indices
        # max_train_size is keyword-only in scikit-learn >= 0.24 (an error
        # when passed positionally from 1.0 on).
        splitter = TimeSeriesSplit(
            n_splits=self.n_splits, max_train_size=self.max_train_size
        )
        # Build one lazy fold generator per group, exactly once.
        fold_iters = [
            (
                positions,
                splitter.split(
                    df.iloc[positions], self._take_rows(y, positions)
                ),
            )
            for positions in group_positions.values()
        ]

        while True:
            train_parts, test_parts = [], []
            for positions, fold_iter in fold_iters:
                try:
                    train, test = next(fold_iter)
                except StopIteration:
                    continue  # this group has no folds left
                # Fold indices are relative to the group's rows; mapping them
                # through the group's positions gives positions in ``df``.
                train_parts.append(positions[train])
                test_parts.append(positions[test])
            if not train_parts:
                return
            yield np.concatenate(train_parts), np.concatenate(test_parts)

    @staticmethod
    def _take_rows(y, positions):
        """Select rows of ``y`` by position (``None`` passes through)."""
        if y is None:
            return None
        if hasattr(y, "iloc"):
            # pandas objects: plain ``y[...]`` would index by label/column.
            return y.iloc[positions]
        return np.asarray(y)[positions]

    def _validate_df(self, df):
        """Check that ``df`` is a valid collection of time series.

        Raises
        ------
        ValueError
            If the index of ``df`` contains duplicates, or if, within any
            group, a row's ``datetime_col`` value is not strictly greater
            than the previous row's value.
        """
        if df.index.duplicated().any():
            raise ValueError("Dataframe has non-unique index.")
        # Previous timestamp within each group (NaN/NaT for a group's first
        # row, which compares False and therefore never triggers the error).
        shift_date = df.groupby(self.groupby)[self.datetime_col].shift()
        if (shift_date >= df[self.datetime_col]).any():
            raise ValueError(
                "Dataframe not in ascending order for each time series."
            )