Source code for caliber.regression.conformal_regression.cqr.base
from typing import Literal
import numpy as np
from numpy.typing import NDArray
from caliber.regression.conformal_regression.base import (
ConformalizedScoreRegressionModel,
)
from caliber.utils.functional import maybe_squeeze
from caliber.utils.quantile_checks import both_quantile_check, single_quantile_check
from caliber.utils.quantile_error import which_quantile_error
[docs]
class ConformalizedQuantileRegressionModel(ConformalizedScoreRegressionModel):
    """Conformalized quantile regression (CQR) on top of pre-computed quantiles.

    The model does not fit a regressor itself. It receives quantile predictions
    from an upstream model, builds CQR conformity scores from them on a
    calibration set, and delegates fitting of the score threshold to
    ``ConformalizedScoreRegressionModel.fit``. At prediction time the fitted
    offset (read from ``self._params``) widens the supplied quantiles.

    Parameters
    ----------
    confidence : float
        Forwarded unchanged to ``ConformalizedScoreRegressionModel``.
    which_quantile : {"both", "lower", "upper"}
        Which bounds the quantile arrays contain:

        * ``"both"``: the first ``y_dim`` columns are lower bounds and the
          remaining columns are upper bounds.
        * ``"lower"``: one lower-bound column per target dimension.
        * ``"upper"``: one upper-bound column per target dimension.
    """

    def __init__(
        self,
        confidence: float,
        which_quantile: Literal["both", "lower", "upper"] = "both",
    ):
        super().__init__(confidence=confidence)
        self.which_quantile = which_quantile

    def fit(self, quantiles: NDArray[np.float64], targets: NDArray[np.float64]) -> None:
        """Compute conformity scores on calibration data and fit the offset.

        Parameters
        ----------
        quantiles : NDArray[np.float64]
            Calibration quantile predictions, laid out as described by
            ``which_quantile``. A 1-D array is treated as a single column.
        targets : NDArray[np.float64]
            Calibration targets, shape ``(n,)`` or ``(n, y_dim)``. A 1-D array
            is treated as a single column.

        Notes
        -----
        The conformity scores, one column per target dimension, are:

        * ``"both"``:  ``max(lower - y, y - upper)``
        * ``"lower"``: ``lower - y``
        * ``"upper"``: ``y - upper``

        A score is positive exactly when the target falls outside the
        supplied bound(s).
        """
        if targets.ndim == 1:
            targets = targets[:, None]
        if quantiles.ndim == 1:
            quantiles = quantiles[:, None]
        # Remembered so that `predict` can split / validate quantile columns.
        self._y_dim = targets.shape[1]

        if self.which_quantile == "both":
            both_quantile_check(quantiles, self._y_dim)
            lowers, uppers = quantiles[:, : self._y_dim], quantiles[:, self._y_dim :]
            scores = np.maximum(lowers - targets, targets - uppers)
        elif self.which_quantile == "lower":
            single_quantile_check(quantiles, self._y_dim)
            scores = quantiles - targets
        elif self.which_quantile == "upper":
            single_quantile_check(quantiles, self._y_dim)
            scores = targets - quantiles
        else:
            # NOTE(review): assumed to raise for an unsupported mode; otherwise
            # `scores` would be unbound below.
            which_quantile_error(self.which_quantile)
        super().fit(scores, targets)

    def predict(self, quantiles: NDArray[np.float64]) -> NDArray[np.float64]:
        """Widen the given quantile predictions by the fitted offset.

        Parameters
        ----------
        quantiles : NDArray[np.float64]
            Quantile predictions with the same column layout used in `fit`.
            A 1-D array is treated as a single column, exactly as in `fit`.

        Returns
        -------
        NDArray[np.float64]
            * ``"both"``: ``lower - offset`` and ``upper + offset``,
              concatenated along axis 1.
            * ``"lower"``: ``quantiles - offset``, passed through
              ``maybe_squeeze(..., 1)``.
            * ``"upper"``: ``quantiles + offset``, passed through
              ``maybe_squeeze(..., 1)``.

            Here ``offset`` is ``self._params``, which is presumably set by
            the parent class's ``fit``.
        """
        # Mirror the reshaping done in `fit` so that 1-D inputs are accepted
        # consistently at calibration and prediction time.
        if quantiles.ndim == 1:
            quantiles = quantiles[:, None]

        if self.which_quantile == "both":
            both_quantile_check(quantiles, self._y_dim)
            lowers, uppers = quantiles[:, : self._y_dim], quantiles[:, self._y_dim :]
            return np.concatenate(
                (lowers - self._params, uppers + self._params), axis=1
            )
        if self.which_quantile == "lower":
            # Fixed: the original referenced the non-existent `self._ydim`.
            single_quantile_check(quantiles, self._y_dim)
            return maybe_squeeze(quantiles - self._params, 1)
        if self.which_quantile == "upper":
            single_quantile_check(quantiles, self._y_dim)
            return maybe_squeeze(quantiles + self._params, 1)
        # NOTE(review): assumed to raise for an unsupported mode.
        which_quantile_error(self.which_quantile)