tempoPFN / src /synthetic_generation /audio_generators /stochastic_rhythm_wrapper.py
Vladyslav Moroshan
Apply ruff formatting
96e1a32
import zlib
from typing import Any

import numpy as np

from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
    StochasticRhythmAudioGenerator,
)
from src.synthetic_generation.generator_params import StochasticRhythmAudioParams
class StochasticRhythmAudioWrapper(GeneratorWrapper):
    """Batch wrapper around ``StochasticRhythmAudioGenerator``.

    Copies the rhythm/audio configuration held in a
    ``StochasticRhythmAudioParams`` into the sampled parameter dict, builds one
    generator per batch, and packs the generated series into a
    ``TimeSeriesContainer``.
    """

    # Per-series seeds are folded into the non-negative 31-bit range.
    _SEED_MASK = 0x7FFFFFFF
    # Knuth-style multiplicative-hash constant (a prime close to 2**32 / phi);
    # spreads consecutive batch indices across the seed space.
    _INDEX_MULTIPLIER = 0x9E3779B1

    def __init__(self, params: StochasticRhythmAudioParams):
        super().__init__(params)
        # Re-annotated so type checkers see the concrete params type.
        self.params: StochasticRhythmAudioParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """Return the base-class sampled parameters extended with generator config.

        Args:
            batch_size: Forwarded unchanged to the base-class sampler.

        Returns:
            The dict produced by ``GeneratorWrapper._sample_parameters``, updated
            with the fixed generator settings taken from ``self.params``.
            ``generate_batch`` additionally reads ``"start"`` and
            ``"frequency"`` from it; presumably the base class supplies those
            (not visible here).
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {
                "length": self.params.length,
                "server_duration": self.params.server_duration,
                "sample_rate": self.params.sample_rate,
                "normalize_output": self.params.normalize_output,
                "base_tempo_hz_range": self.params.base_tempo_hz_range,
                "num_layers_range": self.params.num_layers_range,
                "subdivisions": self.params.subdivisions,
                "attack_range": self.params.attack_range,
                "decay_range": self.params.decay_range,
                "tone_freq_range": self.params.tone_freq_range,
                "tone_mul_range": self.params.tone_mul_range,
            }
        )
        return params

    def _derive_series_seed(self, base_seed: int, index: int) -> int:
        """Derive a deterministic, non-negative per-series seed.

        Mixes the batch seed, the series index, and a salt derived from the
        concrete class name. The salt uses ``zlib.crc32`` rather than the
        builtin ``hash()``. ``str`` hashes are randomized per interpreter
        process (``PYTHONHASHSEED``), which would make the "same" seed produce
        different data across runs.

        Args:
            base_seed: The seed passed to ``generate_batch``.
            index: Position of the series within the batch.

        Returns:
            An int in ``[0, 2**31 - 1]``.
        """
        mask = self._SEED_MASK
        salt = zlib.crc32(type(self).__name__.encode("utf-8"))
        mixed = (
            (base_seed & mask)
            ^ ((index * self._INDEX_MULTIPLIER) & mask)
            ^ (salt & mask)
        )
        return int(mixed)

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """Generate ``batch_size`` stochastic-rhythm series.

        Args:
            batch_size: Number of series to generate.
            seed: Optional seed. When given, it seeds the wrapper's RNGs and the
                generator, and each series gets its own seed derived from it and
                its index. The result is then reproducible across processes.
            params: Optional pre-sampled parameter dict. When omitted, it is
                obtained from ``_sample_parameters``. It must contain the
                generator settings plus ``"start"`` and ``"frequency"``.

        Returns:
            A ``TimeSeriesContainer`` whose ``values`` is ``np.array`` over the
            list of per-series outputs.
        """
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = StochasticRhythmAudioGenerator(
            length=params["length"],
            server_duration=params["server_duration"],
            sample_rate=params["sample_rate"],
            normalize_output=params["normalize_output"],
            base_tempo_hz_range=params["base_tempo_hz_range"],
            num_layers_range=params["num_layers_range"],
            subdivisions=params["subdivisions"],
            attack_range=params["attack_range"],
            decay_range=params["decay_range"],
            tone_freq_range=params["tone_freq_range"],
            tone_mul_range=params["tone_mul_range"],
            random_seed=seed,
        )

        # Without a seed, each series is left to the generator's own RNG state.
        batch_values = [
            generator.generate_time_series(
                random_seed=None if seed is None else self._derive_series_seed(seed, i)
            )
            for i in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )