File size: 3,004 Bytes
1c8d125
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
96e1a32
1c8d125
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import zlib
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import torch

from src.data.containers import TimeSeriesContainer
from src.data.frequency import (
    select_safe_random_frequency,
    select_safe_start_date,
)
from src.synthetic_generation.generator_params import GeneratorParams


class AbstractTimeSeriesGenerator(ABC):
    """
    Abstract base class for synthetic time series generators.
    """

    @abstractmethod
    def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
        """
        Generate synthetic time series data.

        Parameters
        ----------
        random_seed : int, optional
            Random seed for reproducibility.

        Returns
        -------
        np.ndarray
            Time series values of shape (length,) for univariate or
            (length, num_channels) for multivariate time series.
        """
        pass


class GeneratorWrapper:
    """
    Unified base class for all generator wrappers, configured through a
    GeneratorParams dataclass.

    Provides RNG seeding and per-batch parameter sampling helpers; subclasses
    supply ``generate_batch``.

    Notes
    -----
    This class does not derive from ``ABC``, so the ``@abstractmethod`` marker on
    ``generate_batch`` is not enforced at instantiation time. Calling the base
    implementation raises ``NotImplementedError`` instead.
    """

    def __init__(self, params: GeneratorParams):
        """
        Store the configuration and seed all random number generators.

        Parameters
        ----------
        params : GeneratorParams
            Dataclass instance containing all generator configuration parameters.
            Its ``global_seed`` attribute is used as the base seed.
        """
        self.params = params
        self._set_random_seeds(self.params.global_seed)

    def _set_random_seeds(self, seed: int) -> None:
        """
        Seed the wrapper's private RNG and the global NumPy / torch RNGs.

        Parameters
        ----------
        seed : int
            Base seed. It is passed unchanged to ``np.random.seed`` and
            ``torch.manual_seed``; ``self.rng`` is seeded with ``seed`` plus a
            per-class offset.
        """
        # Offset the parameter-sampling seed by a value derived from the class name
        # so that different wrappers draw different parameter sequences from the
        # same base seed.
        # A CRC is used instead of the builtin hash(): str hashes are salted per
        # process (PYTHONHASHSEED), which would make the parameter stream differ
        # between runs even with an identical seed.
        class_offset = zlib.crc32(self.__class__.__name__.encode("utf-8")) % 2**31
        self.rng = np.random.default_rng(seed + class_offset)

        # Set global numpy and torch seeds for deterministic behavior in underlying generators
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Sample a frequency and a start date for every series in a batch.

        Both are chosen by the ``select_safe_*`` helpers from ``self.params.length``
        using ``self.rng``. All frequencies are drawn first, then all start dates,
        so the order of RNG consumption is fixed for a given ``batch_size``.

        Parameters
        ----------
        batch_size : int
            Number of series to sample parameters for.

        Returns
        -------
        dict[str, Any]
            ``"frequency"``: list of ``batch_size`` sampled frequencies.
            ``"start"``: list of ``batch_size`` start dates, where entry ``i`` was
            selected for ``frequency[i]``.
        """
        length = self.params.length

        # Select a suitable frequency based on the total length
        frequency = [select_safe_random_frequency(length, self.rng) for _ in range(batch_size)]
        # Each start date is chosen against the frequency sampled for the same series.
        start = [select_safe_start_date(length, freq, self.rng) for freq in frequency]

        return {
            "frequency": frequency,
            "start": start,
        }

    @abstractmethod
    def generate_batch(self, batch_size: int, seed: int | None = None, **kwargs) -> TimeSeriesContainer:
        """
        Generate a batch of time series; must be overridden by subclasses.

        Parameters
        ----------
        batch_size : int
            Number of series to generate.
        seed : int, optional
            Seed for this batch.
        **kwargs
            Subclass-specific options.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement generate_batch()")