These functions are wrappers around the E_loo function of the loo package.
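
As a rough illustration of the kind of quantity involved, the base R sketch below computes an importance-weighted mean for a single left-out observation. It is a simplification under stated assumptions: the draws and the observation `y_i` are made up, and raw importance ratios are used where the loo package would apply Pareto smoothing via psis.

```r
set.seed(1)

# Hypothetical posterior draws of a mean parameter (stand-in for linear
# predictor draws) and one made-up observation
mu  <- rnorm(1000, mean = 5, sd = 0.3)
y_i <- 5.8

# Pointwise log-likelihood of y_i under each draw
log_lik_i <- dnorm(y_i, mean = mu, sd = 1, log = TRUE)

# Leave-one-out log importance ratios are the negative log-likelihood;
# subtract the maximum before exponentiating for numerical stability
log_ratio <- -log_lik_i
w <- exp(log_ratio - max(log_ratio))
w <- w / sum(w)  # normalize weights to sum to 1

# Importance-weighted (leave-one-out) mean of the draws
loo_mean <- sum(w * mu)
```

The weighted variance and quantiles follow the same pattern, with the normalized weights `w` replacing the equal weights of an ordinary posterior summary.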

# S3 method for brmsfit
loo_predict(
  object,
  type = c("mean", "var", "quantile"),
  probs = 0.5,
  psis_object = NULL,
  resp = NULL,
  ...
)

# S3 method for brmsfit
loo_linpred(
  object,
  type = c("mean", "var", "quantile"),
  probs = 0.5,
  psis_object = NULL,
  resp = NULL,
  ...
)

# S3 method for brmsfit
loo_predictive_interval(object, prob = 0.9, psis_object = NULL, ...)

Arguments

object

An object of class brmsfit.

type

The statistic to be computed on the results. Can be one of "mean" (default), "var", or "quantile".

probs

A vector of quantiles to compute. Only used if type = "quantile".

psis_object

An optional object returned by psis. If psis_object is missing then psis is executed internally, which may be time consuming for models fit to very large datasets.

resp

Optional names of response variables. If specified, predictions are performed only for the specified response variables.

...

Optional arguments passed to the underlying methods, that is, log_lik as well as posterior_predict or posterior_linpred.

prob

For loo_predictive_interval, a scalar in \((0,1)\) indicating the desired probability mass to include in the intervals. The default is prob = 0.9 (\(90\)% intervals).

Value

loo_predict and loo_linpred return a vector with one element per observation. The only exception is if type = "quantile" and length(probs) >= 2, in which case a separate vector for each element of probs is computed and they are returned in a matrix with length(probs) rows and one column per observation.

loo_predictive_interval returns a matrix with one row per observation and two columns.

loo_predictive_interval(..., prob = p) is equivalent to loo_predict(..., type = "quantile", probs = c(a, 1-a)) with a = (1 - p)/2, except it transposes the result and adds informative column names.
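
The mapping from prob to probs and the change of shape can be sketched in base R without fitting a model. In the sketch below, the matrix `q` is a made-up stand-in for the output of loo_predict(..., type = "quantile", probs = probs), and the percentage-style column labels are an assumption about what "informative column names" look like.

```r
# Central interval with probability mass p
p <- 0.8
a <- (1 - p) / 2
probs <- c(a, 1 - a)  # lower and upper tail probabilities

# Stand-in for the quantile output of loo_predict: length(probs) rows,
# one column per observation (values are invented for illustration)
q <- rbind(
  c(4.1, 4.3, 3.9),
  c(5.9, 6.2, 5.6)
)

# Transpose to one row per observation and two columns, then label the
# columns (the label format is an assumption)
interval <- t(q)
colnames(interval) <- paste0(100 * probs, "%")
```

The result has the layout described for loo_predictive_interval: one row per observation, with the lower bound in the first column and the upper bound in the second.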

Examples

if (FALSE) {
library(brms)

## data from help("lm")
ctl <- c(4.17,5.58,5.18,6.11,4.50,4.61,5.17,4.53,5.33,5.14)
trt <- c(4.81,4.17,4.41,3.59,5.87,3.83,6.03,4.89,4.32,4.69)
d <- data.frame(
  weight = c(ctl, trt),
  group = gl(2, 10, 20, labels = c("Ctl", "Trt"))
)
fit <- brm(weight ~ group, data = d)
loo_predictive_interval(fit, prob = 0.8)

## optionally log-weights can be pre-computed and reused
psis <- loo::psis(-log_lik(fit), cores = 2)
loo_predictive_interval(fit, prob = 0.8, psis_object = psis)
loo_predict(fit, type = "var", psis_object = psis)
}