Skip to content

Commit 164a4ee

Browse files
strengejackeDominiqueMakowskiCopilot
authored
Add residualize_over_grid() (#386)
* Add `residualize_over_grid()` * docs * plot * some fixes * fix * fixes * fix * fix * styler * Update R/residualize_over_grid.R Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix * some fixes * minor fix * fix * styler * add tests * news, wordlist --------- Co-authored-by: Dominique Makowski <dom.mak19@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 3bfce0c commit 164a4ee

13 files changed

+1400
-10
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ S3method(print_md,visualisation_matrix)
5656
S3method(reshape_grouplevel,data.frame)
5757
S3method(reshape_grouplevel,default)
5858
S3method(reshape_grouplevel,estimate_grouplevel)
59+
S3method(residualize_over_grid,data.frame)
60+
S3method(residualize_over_grid,estimate_means)
5961
S3method(smoothing,data.frame)
6062
S3method(smoothing,numeric)
6163
S3method(standardize,estimate_contrasts)
@@ -99,6 +101,7 @@ export(pool_slopes)
99101
export(print_html)
100102
export(print_md)
101103
export(reshape_grouplevel)
104+
export(residualize_over_grid)
102105
export(smoothing)
103106
export(standardize)
104107
export(unstandardize)

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131

3232
* `estimate_grouplevel()` now supports models from package *coxme*.
3333

34+
* New function `residualize_over_grid()`, which residualizes a model
35+
over a grid of predictors. This is useful to visualize the residuals of a
36+
model over a grid of predictors.
37+
38+
* `visualisation_recipe()` and `plot()` get a `show_residuals` argument,
39+
to show the residuals of the model, related to the data grid, in the plot.
40+
3441
* Documentation of the `display()` method for *modelbased* objects has been
3542
added.
3643

R/residualize_over_grid.R

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#' @title Compute partial residuals from a data grid
2+
#' @name residualize_over_grid
3+
#'
4+
#' @description This function computes partial residuals based on a data grid,
5+
#' where the data grid is usually a data frame from all combinations of factor
6+
#' variables or certain values of numeric vectors. This data grid is usually used
7+
#' as `newdata` argument in `predict()`, and can be created with
8+
#' [`insight::get_datagrid()`].
9+
#'
10+
#' @param grid A data frame representing the data grid, or an object of class
11+
#' `estimate_means` or `estimate_predicted`, as returned by the different
12+
#' `estimate_*()` functions.
13+
#' @param model The model for which to compute partial residuals. The data grid
14+
#' `grid` should match to predictors in the model.
15+
#' @param predictor_name The name of the focal predictor, for which partial residuals
16+
#' are computed.
17+
#' @param ... Currently not used.
18+
#'
19+
#' @section Partial Residuals:
20+
#' For **generalized linear models** (glms), residualized scores are computed as
21+
#' `inv.link(link(Y) + r)` where `Y` are the predicted values on the response
22+
#' scale, and `r` are the *working* residuals.
23+
#'
24+
#' For (generalized) linear **mixed models**, the random effect are also
25+
#' partialled out.
26+
#'
27+
#' @references
28+
#' Fox J, Weisberg S. Visualizing Fit and Lack of Fit in Complex Regression
29+
#' Models with Predictor Effect Plots and Partial Residuals. Journal of
30+
#' Statistical Software 2018;87.
31+
#'
32+
#' @return A data frame with residuals for the focal predictor.
33+
#'
34+
#' @examplesIf requireNamespace("marginaleffects", quietly = TRUE)
35+
#' set.seed(1234)
36+
#' x1 <- rnorm(200)
37+
#' x2 <- rnorm(200)
38+
#' # quadratic relationship
39+
#' y <- 2 * x1 + x1^2 + 4 * x2 + rnorm(200)
40+
#'
41+
#' d <- data.frame(x1, x2, y)
42+
#' model <- lm(y ~ x1 + x2, data = d)
43+
#'
44+
#' pr <- estimate_means(model, c("x1", "x2"))
45+
#' head(residualize_over_grid(pr, model))
46+
#' @export
47+
residualize_over_grid <- function(grid, model, ...) {
48+
UseMethod("residualize_over_grid")
49+
}
50+
51+
52+
#' @rdname residualize_over_grid
53+
#' @export
54+
residualize_over_grid.data.frame <- function(grid, model, predictor_name, ...) {
55+
old_d <- insight::get_predictors(model)
56+
fun_link <- insight::link_function(model)
57+
inv_fun <- insight::link_inverse(model)
58+
predicted <- grid[[predictor_name]]
59+
grid[[predictor_name]] <- NULL
60+
61+
is_fixed <- sapply(grid, function(x) length(unique(x))) == 1
62+
grid <- grid[, !is_fixed, drop = FALSE]
63+
old_d <- old_d[, colnames(grid)[colnames(grid) %in% colnames(old_d)], drop = FALSE]
64+
65+
if (!.is_grid(grid)) {
66+
insight::format_error("Grid for partial residuals must be a fully crossed grid.")
67+
}
68+
69+
# for each var
70+
best_match <- NULL
71+
72+
for (p in colnames(old_d)) {
73+
if (is.numeric(old_d[[p]])) {
74+
grid[[p]] <- .validate_num(grid[[p]])
75+
}
76+
# if numeric in old data, find where it is closest
77+
best_match <- .closest(old_d[[p]], grid[[p]], best_match = best_match)
78+
}
79+
80+
idx <- apply(best_match, 2, which)
81+
idx <- sapply(idx, "[", 1)
82+
83+
# extract working residuals
84+
res <- .safe(stats::residuals(model, type = "working"))
85+
86+
# if failed, and model linear, extract response residuals
87+
if (is.null(res)) {
88+
minfo <- insight::model_info(model)
89+
if (minfo$is_linear) {
90+
res <- .safe(insight::get_residuals(model, type = "response"))
91+
}
92+
}
93+
94+
if (is.null(res)) {
95+
insight::format_alert("Could not extract residuals.")
96+
return(NULL)
97+
}
98+
99+
my_points <- grid[idx, , drop = FALSE]
100+
my_points[[predictor_name]] <- inv_fun(fun_link(predicted[idx]) + res) # add errors
101+
102+
my_points
103+
}
104+
105+
106+
#' @export
107+
residualize_over_grid.estimate_means <- function(grid, model, ...) {
108+
new_d <- as.data.frame(grid)
109+
110+
relevant_columns <- unique(c(
111+
attributes(grid)$trend,
112+
attributes(grid)$contrast,
113+
attributes(grid)$focal_terms,
114+
attributes(grid)$coef_name
115+
))
116+
117+
new_d <- new_d[colnames(new_d) %in% relevant_columns]
118+
119+
residualize_over_grid(new_d, model, predictor_name = attributes(grid)$coef_name, ...)
120+
}
121+
122+
123+
# utilities --------------------------------------------------------------------
124+
125+
126+
.is_grid <- function(df) {
127+
unq <- lapply(df, unique)
128+
129+
if (prod(lengths(unq)) != nrow(df)) {
130+
return(FALSE)
131+
}
132+
133+
df2 <- do.call(expand.grid, args = unq)
134+
df2$..1 <- 1
135+
136+
res <- merge(df, df2, by = colnames(df), all = TRUE)
137+
138+
sum(res$..1) == sum(df2$..1)
139+
}
140+
141+
142+
.closest <- function(x, target, best_match) {
143+
if (is.numeric(x)) {
144+
145+
AD <- abs(outer(x, target, FUN = `-`))
146+
idx <- apply(AD, 1, function(x) x == min(x))
147+
} else {
148+
idx <- t(outer(x, target, FUN = `==`))
149+
}
150+
151+
if (is.matrix(best_match)) {
152+
idx <- idx & best_match
153+
}
154+
155+
idx
156+
}
157+
158+
159+
.validate_num <- function(x) {
160+
if (!is.numeric(x)) {
161+
x <- as.numeric(as.character(x))
162+
}
163+
x
164+
}

R/visualisation_recipe.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
#' predictor. Use `FALSE` to always use continuous color scales for numeric
4040
#' predictors. It is possible to set a global default value using `options()`,
4141
#' e.g. `options(modelbased_numeric_as_discrete = 10)`.
42+
#' @param show_residuals Logical, if `TRUE`, display residuals of the model
43+
#' as a background to the model-based estimation. Residuals will be computed
44+
#' for the predictors in the data grid, using [`residualize_over_grid()`].
4245
#' @param point,line,pointrange,ribbon,facet,grid Additional
4346
#' aesthetics and parameters for the geoms (see customization example).
4447
#' @param ... Arguments passed from `plot()` to `visualisation_recipe()`, or
@@ -150,6 +153,7 @@
150153
#' @export
151154
visualisation_recipe.estimate_predicted <- function(x,
152155
show_data = FALSE,
156+
show_residuals = FALSE,
153157
point = NULL,
154158
line = NULL,
155159
pointrange = NULL,
@@ -173,6 +177,7 @@ visualisation_recipe.estimate_predicted <- function(x,
173177
.visualization_recipe(
174178
x,
175179
show_data = show_data,
180+
show_residuals = show_residuals,
176181
point = point,
177182
line = line,
178183
pointrange = pointrange,
@@ -230,6 +235,7 @@ visualisation_recipe.estimate_slopes <- function(x,
230235
.visualization_recipe(
231236
x,
232237
show_data = FALSE,
238+
show_residuals = FALSE,
233239
line = line,
234240
pointrange = pointrange,
235241
ribbon = ribbon,
@@ -287,6 +293,7 @@ visualisation_recipe.estimate_grouplevel <- function(x,
287293
.visualization_recipe(
288294
x,
289295
show_data = FALSE,
296+
show_residuals = FALSE,
290297
line = line,
291298
pointrange = pointrange,
292299
ribbon = ribbon,

R/visualisation_recipe_internal.R

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@
246246
#' @keywords internal
247247
.visualization_recipe <- function(x,
248248
show_data = TRUE,
249+
show_residuals = FALSE,
249250
point = NULL,
250251
line = NULL,
251252
pointrange = NULL,
@@ -276,7 +277,7 @@
276277

277278
# Don't plot raw data if `predict` is not on the response scale
278279
if (!is.null(response_scale) && !response_scale %in% c("prediction", "response", "expectation", "invlink(link)")) {
279-
show_data <- FALSE
280+
show_data <- show_residuals <- FALSE
280281
}
281282

282283
# Don't plot raw data for transformed responses with no back-transformation
@@ -286,20 +287,27 @@
286287
# add information about response transformation
287288
trans_fun <- .safe(insight::find_transformation(attributes(x)$model))
288289
if (!is.null(trans_fun) && all(trans_fun != "identity")) {
289-
show_data <- FALSE
290+
show_data <- show_residuals <- FALSE
290291
}
291292
}
292293

293-
294294
# add raw data as first layer ----------------------------------
295295
if (show_data) {
296-
layers[[paste0("l", l)]] <- .visualization_recipe_rawdata(x, aes)
296+
layers[[paste0("l", l)]] <- .visualization_recipe_rawdata(x, aes, numeric_as_discrete)
297297
# Update with additional args
298298
if (!is.null(point)) layers[[paste0("l", l)]] <- utils::modifyList(layers[[paste0("l", l)]], point)
299299
l <- l + 1
300300
}
301301

302302

303+
# add residual data as next lowest layer
304+
if (show_residuals) {
305+
layers[[paste0("l", l)]] <- .visualization_recipe_residuals(x, aes, numeric_as_discrete)
306+
# Update with additional args
307+
if (!is.null(point)) layers[[paste0("l", l)]] <- utils::modifyList(layers[[paste0("l", l)]], point)
308+
l <- l + 1
309+
}
310+
303311
# intercept line for slopes ----------------------------------
304312
if (inherits(x, "estimate_slopes")) {
305313
layers[[paste0("l", l)]] <- insight::compact_list(list(
@@ -469,15 +477,64 @@
469477

470478

471479
#' @keywords internal
472-
.visualization_recipe_rawdata <- function(x, aes) {
480+
.visualization_recipe_rawdata <- function(x, aes, numeric_as_discrete = 8) {
473481
model <- attributes(x)$model
474482
rawdata <- insight::get_data(model, verbose = FALSE)
475483

476484
# Add response to data if not there
477485
y <- insight::find_response(attributes(x)$model)
478-
if (!y %in% names(rawdata)) rawdata[y] <- insight::get_response(attributes(x)$model, verbose = FALSE)
486+
if (!y %in% names(rawdata)) {
487+
rawdata[y] <- insight::get_response(attributes(x)$model, verbose = FALSE)
488+
}
489+
490+
# if we have less than 8 values for the legend, a continuous color scale
491+
# is used by default - we then must convert values into factors, when we
492+
# show data or residuals - but we must ensure that the levels are sorted
493+
# according to the original data grid, thus we need "sort()"
494+
if (!is.null(aes$color) && is.numeric(rawdata[[aes$color]]) && insight::n_unique(rawdata[[aes$color]]) < numeric_as_discrete) {
495+
new_values <- insight::format_value(rawdata[[aes$color]], protect_integers = TRUE)
496+
rawdata[[aes$color]] <- factor(new_values, levels = as.character(sort(as.numeric(unique(new_values)))))
497+
}
498+
499+
.data_point_geom(
500+
model = model,
501+
aes = aes,
502+
data = rawdata,
503+
y = y
504+
)
505+
}
479506

480-
if (aes$type == "pointrange" && !is.numeric(rawdata[[aes$x]])) {
507+
508+
# residuals ----------------------------------------------------------------
509+
510+
511+
#' @keywords internal
512+
.visualization_recipe_residuals <- function(x, aes, numeric_as_discrete = 8) {
513+
model <- attributes(x)$model
514+
residual_data <- residualize_over_grid(x, model)
515+
516+
# if we have less than 8 values for the legend, a continuous color scale
517+
# is used by default - we then must convert values into factors, when we
518+
# show data or residuals - but we must ensure that the levels are sorted
519+
# according to the original data grid, thus we need "sort()"
520+
if (!is.null(aes$color) && is.numeric(residual_data[[aes$color]]) && insight::n_unique(residual_data[[aes$color]]) < numeric_as_discrete) {
521+
new_values <- insight::format_value(residual_data[[aes$color]], protect_integers = TRUE)
522+
residual_data[[aes$color]] <- factor(new_values, levels = as.character(sort(as.numeric(unique(new_values)))))
523+
}
524+
525+
.data_point_geom(
526+
model = model,
527+
aes = aes,
528+
data = residual_data,
529+
y = "Mean"
530+
)
531+
}
532+
533+
534+
# helpers -----------------------------------------------------------------
535+
536+
.data_point_geom <- function(model, aes, data, y) {
537+
if (aes$type == "pointrange" && !is.numeric(data[[aes$x]])) {
481538
geom <- "jitter"
482539
} else {
483540
geom <- "point"
@@ -493,7 +550,7 @@
493550

494551
out <- list(
495552
geom = geom,
496-
data = rawdata,
553+
data = data,
497554
aes = list(
498555
y = y,
499556
x = aes$x,
@@ -508,10 +565,10 @@
508565
# check if we have matching columns in the raw data - some functions,
509566
# likes slopes, have mapped these aes to other columns that are not part
510567
# of the raw data - we set them to NULL
511-
if (!is.null(aes$color) && !aes$color %in% colnames(rawdata)) {
568+
if (!is.null(aes$color) && !aes$color %in% colnames(data)) {
512569
out$aes$color <- NULL
513570
}
514-
if (!is.null(aes$alpha) && !aes$alpha %in% colnames(rawdata)) {
571+
if (!is.null(aes$alpha) && !aes$alpha %in% colnames(data)) {
515572
out$aes$alpha <- NULL
516573
}
517574

0 commit comments

Comments
 (0)