File size: 5,514 Bytes
1c8d125 96e1a32 1c8d125 96e1a32 1c8d125 96e1a32 1c8d125 96e1a32 1c8d125 96e1a32 1c8d125 96e1a32 1c8d125 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gpytorch
import numpy as np
from gpytorch.kernels import AdditiveKernel, PeriodicKernel, ProductKernel, ScaleKernel
def custom_gaussian_sample(
    max_period_length,
    kernel_periods=None,
    gaussian_sample=True,
    allow_extension=True,
    rng=None,
):
    """Sample an integer period length around a randomly chosen candidate mean.

    A candidate mean is drawn uniformly from ``kernel_periods`` (or a built-in
    default list), optionally extended with longer periods derived from
    ``max_period_length``. Candidates larger than ``max_period_length`` are
    discarded before the draw. When ``gaussian_sample`` is true, the result is
    drawn from a normal distribution centred on the chosen mean with standard
    deviation ``sqrt(mean) ** 1.2`` (i.e. ``mean ** 0.6``). The result may
    therefore exceed ``max_period_length``.

    Parameters
    ----------
    max_period_length
        Upper bound used to filter candidate means and to derive extension
        periods.
    kernel_periods
        Candidate means. ``None`` selects
        ``[3, 5, 7, 14, 20, 21, 24, 30, 60, 90, 120]``. An empty sequence is
        accepted when ``allow_extension`` is true; the derived extension
        periods then become the only candidates.
    gaussian_sample
        If true, add Gaussian jitter around the selected mean. Otherwise the
        mean itself is returned.
    allow_extension
        If true, append longer candidate periods before filtering:
        - for ``max_period_length > 200``, periods spaced 100 apart up to
          (but excluding) ``max_period_length``;
        - otherwise ``max_period_length // 2`` and/or ``max_period_length``,
          depending on the largest existing candidate.
    rng
        ``numpy.random.Generator``. A fresh default generator is created when
        ``None``.

    Returns
    -------
    int
        The sampled period, always at least 1.

    Raises
    ------
    ValueError
        If no candidate mean remains after filtering.
    """
    if rng is None:
        rng = np.random.default_rng()
    if kernel_periods is None:
        kernel_periods = [3, 5, 7, 14, 20, 21, 24, 30, 60, 90, 120]
    means = np.array(kernel_periods)

    if allow_extension:
        # An empty candidate list behaves as if its largest period were 0, so
        # the extension still supplies candidates instead of failing on max().
        largest = means.max() if means.size else 0
        if max_period_length > 200:
            # Start at half the horizon, or 100 past the largest candidate if
            # that is already beyond half, and step by 100.
            start = (
                max_period_length // 2
                if largest < max_period_length // 2
                else largest + 100
            )
            means = np.append(means, np.arange(start, max_period_length, 100))
        elif largest < max_period_length / 2:
            means = np.append(means, np.array([max_period_length // 2, max_period_length]))
        elif largest < max_period_length:
            means = np.append(means, max_period_length)

    means = means[means <= max_period_length]
    if means.size == 0:
        raise ValueError(
            f"No candidate period <= max_period_length={max_period_length} "
            f"(kernel_periods={kernel_periods!r}, allow_extension={allow_extension})"
        )

    selected_mean = rng.choice(means)
    if gaussian_sample:
        # Spread grows sub-linearly with the period: sqrt(mean) ** 1.2.
        selected_std = np.sqrt(selected_mean) ** 1.2
        sample = rng.normal(selected_mean, selected_std)
    else:
        sample = selected_mean

    if sample < 1:
        # Reflect negative draws and round up. The max() keeps an exact 0 from
        # turning into a zero-length period.
        sample = max(1.0, np.ceil(np.abs(sample)))
    return int(sample)
def create_kernel(
    kernel: str,
    seq_len: int,
    max_period_length: int = 365,
    max_degree: int = 5,
    gaussians_periodic: bool = False,
    kernel_periods=None,
    kernel_counter=None,
    freq=None,
    exact_freqs=False,
    gaussian_sample=True,
    subfreq="",
    rng=None,
):
    """Build a randomly parameterised GPyTorch kernel selected by name.

    Parameters
    ----------
    kernel
        One of ``"linear_kernel"``, ``"rbf_kernel"``, ``"periodic_kernel"``,
        ``"polynomial_kernel"``, ``"matern_kernel"``,
        ``"rational_quadratic_kernel"`` or ``"spectral_mixture_kernel"``.
    seq_len
        The sampled period length is divided by this value before being
        stored on a PeriodicKernel.
    max_period_length
        Bound passed to the period sampler.
    max_degree
        Exclusive upper bound for the polynomial degree.
    gaussians_periodic
        If true, periodic kernels draw their period via
        ``custom_gaussian_sample``. Otherwise the period is a uniform integer.
    kernel_periods
        Candidate periods forwarded to ``custom_gaussian_sample`` (``None``
        uses its defaults).
    kernel_counter
        Mutable mapping with a ``"periodic_kernel"`` entry. It is read, and
        decremented in place, on the exact-frequency path.
    freq, exact_freqs, subfreq
        Control whether the exact-frequency period path is taken. It requires
        ``exact_freqs``, ``freq != "Y"`` and a ``kernel_counter``.
    gaussian_sample
        Forwarded to ``custom_gaussian_sample`` on the exact-frequency path
        only.
    rng
        ``numpy.random.Generator``. A fresh default generator is created when
        ``None``.

    Returns
    -------
    The constructed kernel, wrapped in a ScaleKernel with probability 1/2.

    Raises
    ------
    ValueError
        If ``kernel`` is not a recognised name.
    """
    if rng is None:
        rng = np.random.default_rng()
    # Drawn unconditionally, even for kernel types that do not use them.
    scale_kernel = rng.choice([True, False])
    lengthscale = rng.uniform(0.1, 5.0)

    name = kernel
    if name == "linear_kernel":
        # NOTE(review): the Gamma prior is only registered on the kernel here;
        # confirm whether callers sample hyperparameters from it.
        sigma_prior = gpytorch.priors.GammaPrior(rng.uniform(1, 6), rng.uniform(0.1, 1))
        result = gpytorch.kernels.LinearKernel(variance_prior=sigma_prior)
    elif name == "rbf_kernel":
        result = gpytorch.kernels.RBFKernel()
        result.lengthscale = lengthscale
    elif name == "periodic_kernel":
        if gaussians_periodic:
            if exact_freqs and freq != "Y" and kernel_counter is not None:
                remaining = kernel_counter["periodic_kernel"]
                # With at most two periodic kernels left and no sub-frequency,
                # drop the last three candidate periods. A None kernel_periods
                # is passed through unchanged (the sampler's defaults are used),
                # since slicing None would raise TypeError.
                trim = remaining <= 2 and subfreq == "" and kernel_periods is not None
                periods = kernel_periods[:-3] if trim else kernel_periods
                period_length = custom_gaussian_sample(
                    max_period_length,
                    kernel_periods=periods,
                    gaussian_sample=gaussian_sample,
                    allow_extension=(remaining > 2),
                    rng=rng,
                )
                kernel_counter["periodic_kernel"] -= 1
            else:
                # NOTE(review): this path always jitters (gaussian_sample=True)
                # regardless of the gaussian_sample argument; confirm intended.
                period_length = custom_gaussian_sample(
                    max_period_length, kernel_periods, gaussian_sample=True, rng=rng
                )
        else:
            # Upper bound is exclusive: period_length is in [1, max_period_length).
            period_length = rng.integers(1, max_period_length)
        result = gpytorch.kernels.PeriodicKernel()
        # Stored relative to seq_len; extract_periodicities multiplies it back.
        result.period_length = period_length / seq_len
        result.lengthscale = lengthscale
    elif name == "polynomial_kernel":
        offset_prior = gpytorch.priors.GammaPrior(rng.uniform(1, 4), rng.uniform(0.1, 1))
        # Upper bound is exclusive: degree is in [1, max_degree).
        degree = rng.integers(1, max_degree)
        result = gpytorch.kernels.PolynomialKernel(offset_prior=offset_prior, power=degree)
    elif name == "matern_kernel":
        nu = rng.choice([0.5, 1.5, 2.5])  # Roughness parameter
        result = gpytorch.kernels.MaternKernel(nu=nu)
        result.lengthscale = lengthscale
    elif name == "rational_quadratic_kernel":
        alpha = rng.uniform(0.1, 10.0)  # Scale mixture parameter
        # RQKernel's constructor has no ``alpha`` argument; an ``alpha=`` keyword
        # is silently swallowed by the base Kernel, so set it via the property.
        result = gpytorch.kernels.RQKernel()
        result.alpha = alpha
        result.lengthscale = lengthscale
    elif name == "spectral_mixture_kernel":
        num_mixtures = rng.integers(2, 6)  # Number of spectral mixture components
        result = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=num_mixtures)
    else:
        raise ValueError(f"Unknown kernel: {kernel}")

    if scale_kernel:
        result = gpytorch.kernels.ScaleKernel(result)
    return result
def extract_periodicities(kernel, seq_len):
    """Collect the period length of every PeriodicKernel inside ``kernel``.

    The kernel tree is walked depth-first, left to right. Traversal descends
    through the children of AdditiveKernel/ProductKernel and the base kernel
    of ScaleKernel; any other kernel type contributes nothing. Each stored
    period is multiplied by ``seq_len`` before being reported.

    Returns
    -------
    list
        One entry per PeriodicKernel found, in traversal order.
    """
    found = []
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they are popped, and reported, in their original left-to-right order.
    stack = [kernel]
    while stack:
        node = stack.pop()
        if isinstance(node, PeriodicKernel):
            found.append(node.period_length.item() * seq_len)
        elif isinstance(node, (AdditiveKernel, ProductKernel)):
            stack.extend(reversed(list(node.kernels)))
        elif isinstance(node, ScaleKernel):
            stack.append(node.base_kernel)
    return found
def random_binary_map(a: gpytorch.kernels.Kernel, b: gpytorch.kernels.Kernel, rng=None):
    """
    Combine two kernels with a randomly chosen binary operator.

    Addition and multiplication are equally likely.

    Parameters
    ----------
    a
        Left-hand GP kernel.
    b
        Right-hand GP kernel.
    rng
        ``numpy.random.Generator`` used for the choice. A new default
        generator is created when ``None``.

    Returns
    -------
    The composite kernel ``a + b`` or ``a * b``.
    """

    def _add(left, right):
        return left + right

    def _multiply(left, right):
        return left * right

    generator = np.random.default_rng() if rng is None else rng
    combine = generator.choice([_add, _multiply])
    return combine(a, b)
|