I've been using gammit::predict_gamm() (which I understand uses mgcv::predict.gam()) to predict from GAMM/GLMM models quite happily for a while. Recently, while rerunning my workflow on training and test datasets, I discovered an issue with regenerating predicted values, which affects the reproducibility of my work.
When I request predictions that include a random effect using the argument re_form = c("s(Country_fact)"), the predicted values appear to equal the random intercept coefficient of each observation's group (e.g. its Country_fact level). Hence all observations in the same subject/group get the same predicted value, regardless of how the other variables vary across the test data. When I predict without random effects using re_form = NA, I get predicted values that vary as expected. I have created a toy dataset that reproduces this result; reprex below (edited following Roland's comments).
This seems to happen regardless of model specification (GLMM or GAMM).
Is there something I have completely misunderstood about predicting with GAMMs? Could someone please test this and explain why I am getting these predicted values? Thanks very much in advance!
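# packages used in this reprex
library(mgcv)
library(gammit)
library(tibble)   # tibble(), as_tibble()
library(dplyr)    # %>%
library(stringi)  # stri_rand_strings()
library(stringr)  # str_replace_all()
# create training data
set.seed(123)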
dat <- tibble(
Reference_fact = as.factor(rep(seq(1,10,1),5)),
Country_fact = as.factor(rep(stri_rand_strings(10,2),5)),
exp_A = runif(50,0,500),
exp_B = runif(50,0,20),
resp = exp_A^2 + 0.1*exp_B * 0.3*exp_A*exp_B + rnorm(50, 0, 10)
)
# define model specification
model_spec <- c("resp ~
s(exp_A) +
s(exp_B) +
s(Reference_fact, bs = 're') +
s(Country_fact, bs = 're')"
)
# fit model (GAMM: smooths of exp_A and exp_B plus two random intercepts)
fit <- mgcv::gam(formula(str_replace_all(model_spec, "[\r\n]", "")),
method = 'REML',
family = 'gaussian',
data=dat)
# create test data
set.seed(123)
dat_pred <- tibble(
Reference_fact = as.factor(rep(seq(1,10,1),5)),
Country_fact = as.factor(rep(stri_rand_strings(10,2),5)),
exp_A = runif(50,0,500),
exp_B = runif(50,0,20)
)
# predict with random effects
p <- gammit::predict_gamm(
fit,
dat_pred,
re_form = c("s(Country_fact)"),
keep_prediction_data = TRUE,
newdata.guaranteed = FALSE,
se.fit = FALSE)
I got the result below, where the predicted values correspond to the random effects returned by extract_ranef(fit). This does not seem right. When predicting without random effects using the same package (gammit), unsurprisingly I get a different set of results.
> p %>% as_tibble()
# A tibble: 50 × 5
Reference_fact Country_fact exp_A exp_B prediction
<fct> <fct> <dbl> <dbl> <dbl>
1 1 Hm 445. 15.1 0.0156
2 2 Ps 346. 12.6 0.0137
3 3 w2 320. 14.2 -0.00998
4 4 Wt 497. 0.0125 -0.0117
5 5 YS 328. 9.51 0.0102
6 6 xS 354. 4.40 -0.0121
7 7 gZ 272. 7.60 -0.00652
8 8 6t 297. 12.3 0.00131
9 9 F2 145. 7.04 -0.00568
10 10 Kx 73.6 2.22 0.00514
# ℹ 40 more rows
# ℹ Use `print(n = ...)` to see more rows
> extract_ranef(fit)
# A tibble: 20 × 7
group_var effect group value se lower_2.5 upper_97.5
<chr> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
1 Reference_fact Intercept 1 0.016 1.73 -3.38 3.41
2 Reference_fact Intercept 2 0.014 1.73 -3.38 3.40
3 Reference_fact Intercept 3 -0.01 1.73 -3.40 3.38
4 Reference_fact Intercept 4 -0.012 1.73 -3.40 3.38
5 Reference_fact Intercept 5 0.01 1.73 -3.38 3.40
6 Reference_fact Intercept 6 -0.012 1.73 -3.40 3.38
7 Reference_fact Intercept 7 -0.007 1.73 -3.40 3.38
8 Reference_fact Intercept 8 0.001 1.73 -3.39 3.39
9 Reference_fact Intercept 9 -0.006 1.73 -3.40 3.38
10 Reference_fact Intercept 10 0.005 1.73 -3.39 3.40
11 Country_fact Intercept 6t 0.001 1.73 -3.39 3.39
12 Country_fact Intercept F2 -0.006 1.73 -3.40 3.38
13 Country_fact Intercept gZ -0.007 1.73 -3.40 3.38
14 Country_fact Intercept Hm 0.016 1.73 -3.38 3.41
15 Country_fact Intercept Kx 0.005 1.73 -3.39 3.40
16 Country_fact Intercept Ps 0.014 1.73 -3.38 3.40
17 Country_fact Intercept w2 -0.01 1.73 -3.40 3.38
18 Country_fact Intercept Wt -0.012 1.73 -3.40 3.38
19 Country_fact Intercept xS -0.012 1.73 -3.40 3.38
20 Country_fact Intercept YS 0.01 1.73 -3.38 3.40
# predict without random effects
p_no_re <- gammit::predict_gamm(
fit,
dat_pred,
re_form = NA,
keep_prediction_data = TRUE,
newdata.guaranteed = FALSE,
se.fit = FALSE)
> p_no_re %>% as_tibble()
# A tibble: 50 × 5
Reference_fact Country_fact exp_A exp_B prediction
<fct> <fct> <dbl> <dbl> <dbl>
1 1 Hm 445. 15.1 200361.
2 2 Ps 346. 12.6 121563.
3 3 w2 320. 14.2 104321.
4 4 Wt 497. 0.0125 246972.
5 5 YS 328. 9.51 108463.
6 6 xS 354. 4.40 126635.
7 7 gZ 272. 7.60 74660.
8 8 6t 297. 12.3 89526.
9 9 F2 145. 7.04 20891.
10 10 Kx 73.6 2.22 5062.
# ℹ 40 more rows
# ℹ Use `print(n = ...)` to see more rows
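As a cross-check independent of gammit, the cases being compared here (every random effect, none, or only one grouping factor's intercepts) can be expressed with mgcv::predict.gam()'s documented exclude argument, which sets the named smooth terms to zero on the link scale. This is a minimal sketch on hypothetical simulated data (df, x, g1 and g2 are illustrative names, not from the reprex), and it is not necessarily how gammit implements re_form.

```r
library(mgcv)

set.seed(1)
# hypothetical toy data: one smooth covariate and two grouping factors
n  <- 200
df <- data.frame(
  x  = runif(n),
  g1 = factor(sample(letters[1:8], n, replace = TRUE)),
  g2 = factor(sample(LETTERS[1:5], n, replace = TRUE))
)
u1 <- rnorm(8, 0, 1)   # true g1 intercepts
u2 <- rnorm(5, 0, 2)   # true g2 intercepts
df$y <- sin(2 * pi * df$x) + u1[df$g1] + u2[df$g2] + rnorm(n, 0, 0.3)

fit <- gam(y ~ s(x) + s(g1, bs = "re") + s(g2, bs = "re"),
           data = df, method = "REML")

# term labels of the random-effect smooths (bs = "re" gives class "random.effect")
is_re     <- vapply(fit$smooth, inherits, logical(1), what = "random.effect")
re_labels <- vapply(fit$smooth, function(s) s$label, character(1))[is_re]

p_all  <- predict(fit, newdata = df)                       # every random effect
p_none <- predict(fit, newdata = df, exclude = re_labels)  # no random effects
p_g2   <- predict(fit, newdata = df,
                  exclude = setdiff(re_labels, "s(g2)"))   # g2 intercepts only

head(cbind(all = p_all, none = p_none, g2_only = p_g2))
```

In this setup p_g2 should differ from p_none only by each row's estimated g2 intercept, while both still vary with x through s(x).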
When I use mgcv::predict.gam, I get the same results regardless of whether the re.form argument includes the random effect or not. These are also essentially identical (differing at most in the last printed digit) to the results from gammit::predict_gamm with re_form = NA, i.e. with no random effects.
pg <- mgcv::predict.gam(
fit,
dat_pred,
re.form = c("s(Country_fact)"),
keep_prediction_data = TRUE,
newdata.guaranteed = FALSE,
se.fit = FALSE)
> pg %>% as_tibble()
# A tibble: 50 × 1
value
<dbl>
1 200361.
2 121564.
3 104321.
4 246972.
5 108463.
6 126634.
7 74660.
8 89526.
9 20891.
10 5062.
pg_no_re <- mgcv::predict.gam(
fit,
dat_pred,
re.form = ~0,
#re.form = c("s(Country_fact)"),
keep_prediction_data = TRUE,
newdata.guaranteed = FALSE,
se.fit = FALSE)
> pg_no_re %>% as_tibble()
# A tibble: 50 × 1
value
<dbl>
1 200361.
2 121564.
3 104321.
4 246972.
5 108463.
6 126634.
7 74660.
8 89526.
9 20891.
10 5062.
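One quick diagnostic for the mgcv results is to check whether re.form is one of predict.gam()'s named arguments at all. Its signature ends in `...`, so an unrecognised name such as re.form (or keep_prediction_data) is accepted without an error; the documented arguments for choosing which terms enter the prediction are terms and exclude. A minimal sketch on hypothetical simulated data (df, x and g are illustrative names):

```r
library(mgcv)

# named arguments of predict.gam(); any other name ends up in `...`
args_pg <- names(formals(mgcv::predict.gam))
c("re.form", "keep_prediction_data", "terms", "exclude") %in% args_pg

set.seed(2)
# hypothetical toy data: one smooth covariate and one grouping factor
df <- data.frame(
  x = runif(100),
  g = factor(sample(letters[1:6], 100, replace = TRUE))
)
df$y <- cos(2 * pi * df$x) + rnorm(6)[df$g] + rnorm(100, 0, 0.2)
fit <- gam(y ~ s(x) + s(g, bs = "re"), data = df, method = "REML")

p_plain   <- predict(fit, newdata = df)
p_reform  <- predict(fit, newdata = df, re.form = ~0)     # accepted via `...`
p_exclude <- predict(fit, newdata = df, exclude = "s(g)") # random effect zeroed

identical(p_plain, p_reform)
```

If identical() is TRUE while p_exclude differs, re.form is having no effect on the prediction in that call.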
So the question is: I would have expected the re_form/re.form argument in either package to generate predicted values that make use of the country-specific intercepts. Why is this not the case?
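To make the expected behaviour concrete: a link-scale GAMM prediction is the intercept, plus each smooth's contribution, plus the relevant random intercepts. The sketch below (hypothetical simulated data with a single grouping factor g) uses predict.gam(type = "terms") to rebuild the full prediction from its parts, and shows that the random-effect column on its own is just each group's estimated intercept: constant within a group and unaffected by x. That pattern resembles the first table above, where each prediction matches a value from extract_ranef(fit); whether gammit is returning only that term contribution is an assumption this sketch does not test.

```r
library(mgcv)

set.seed(3)
# hypothetical toy data: smooth covariate x and one grouping factor g
n  <- 150
df <- data.frame(
  x = runif(n, 0, 10),
  g = factor(sample(paste0("c", 1:5), n, replace = TRUE))
)
df$y <- 2 + sin(df$x) + rnorm(5, 0, 2)[df$g] + rnorm(n, 0, 0.5)
fit <- gam(y ~ s(x) + s(g, bs = "re"), data = df, method = "REML")

# per-term contributions on the link scale; the intercept is an attribute
tm <- predict(fit, newdata = df, type = "terms")
b0 <- attr(tm, "constant")

# full prediction = intercept + smooth of x + this row's group intercept
p_rebuilt <- b0 + tm[, "s(x)"] + tm[, "s(g)"]

# the random-effect column alone is the group's estimated intercept,
# so it is constant within each level of g and ignores x entirely
re_only  <- tm[, "s(g)"]
re_coefs <- coef(fit)[paste0("s(g).", seq_len(nlevels(df$g)))]
```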
Session Info:
> sessionInfo()
R version 4.3.0 (2023-04-21 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)
Matrix products: default
locale:
[1] LC_COLLATE=English_Australia.utf8 LC_CTYPE=English_Australia.utf8 LC_MONETARY=English_Australia.utf8
[4] LC_NUMERIC=C LC_TIME=English_Australia.utf8
time zone: Australia/Sydney
tzcode source: internal
attached base packages:
[1] grid stats graphics grDevices datasets utils methods base
other attached packages:
[1] stringi_1.7.12 fstcore_0.9.14 gammit_0.3.2 furrr_0.3.1 rsample_1.1.1
[6] ggpubr_0.6.0 cowplot_1.1.1 tmap_3.3-3 fst_0.9.8 parallelly_1.35.0
[11] purrr_1.0.1 future_1.32.0 rlang_1.1.1 qs_0.25.5 terra_1.7-29
[16] readr_2.1.4 targets_1.0.0 renv_0.17.3 itsadug_2.4.1 plotfunctions_1.4
[21] mgcv_1.8-42 nlme_3.1-162 metafor_4.0-0 numDeriv_2016.8-1.1 metadat_1.2-0
[26] Matrix_1.5-4 forestplot_3.1.1 checkmate_2.2.0 GGally_2.1.2 ggExtra_0.10.0
[31] ggeffects_1.2.1 rasterize_0.1 rworldxtra_1.01 cleangeo_0.2-4 rgeos_0.6-2
[36] exactextractr_0.9.1 groupdata2_2.0.2 cvms_1.3.9 VIM_6.2.2 colorspace_2.1-0
[41] Rcpp_1.0.10 miceadds_3.16-18 mice_3.15.0 ncdf4_1.21 compareDF_2.3.5
[46] stringr_1.5.0 R.utils_2.12.2 R.oo_1.25.0 R.methodsS3_1.8.2 rworldmap_1.3-6
[51] here_1.0.1 RNetCDF_2.6-2 devtools_2.4.5 usethis_2.1.6 stars_0.6-1
[56] abind_1.4-5 ncmeta_0.3.5 tidync_0.3.0 ggthemes_4.2.4 gridBase_0.4-7
[61] raster_3.6-20 gridExtra_2.3 htmlwidgets_1.6.2 plotly_4.10.1 broom_1.0.4
[66] rgdal_1.6-6 geojsonio_0.11.0 sf_1.0-12 maps_3.4.1 ggmap_3.0.2
[71] maptools_1.1-6 sp_1.6-0 viridis_0.6.2 viridisLite_0.4.1 networkD3_0.4
[76] lubridate_1.9.2 zoo_1.8-12 Hmisc_5.0-1 readxl_1.4.2 ggplot2_3.4.2
[81] data.table_1.14.8 forecast_8.21 janitor_2.2.0 magrittr_2.0.3 tidyr_1.3.0
[86] dplyr_1.1.2 rvest_1.0.3
loaded via a namespace (and not attached):
[1] fs_1.6.2 bitops_1.0-7 httr_1.4.5 RColorBrewer_1.1-3 profvis_0.3.7
[6] tools_4.3.0 backports_1.4.1 utf8_1.2.3 R6_2.5.1 lazyeval_0.2.2
[11] urlchecker_1.0.1 withr_2.5.0 prettyunits_1.1.1 leaflet_2.1.2 leafem_0.2.0
[16] cli_3.6.1 tseries_0.10-53 robustbase_0.95-1 proxy_0.4-27 foreign_0.8-84
[21] dichromat_2.0-0.1 sessioninfo_1.2.2 TTR_0.24.3 rstudioapi_0.14 httpcode_0.3.0
[26] RApiSerialize_0.1.2 generics_0.1.3 crosstalk_1.2.0 car_3.1-2 fansi_1.0.4
[31] gratia_0.8.1 lifecycle_1.0.3 yaml_2.3.7 snakecase_0.11.0 carData_3.0-5
[36] mathjaxr_1.6-0 tmaptools_3.1-1 promises_1.2.0.1 crayon_1.5.2 miniUI_0.1.1.1
[41] lattice_0.21-8 geojson_0.3.4 pillar_1.9.0 knitr_1.42 boot_1.3-28.1
[46] future.apply_1.10.0 codetools_0.2-19 glue_1.6.2 urca_1.3-3 V8_4.3.0
[51] remotes_2.4.2 vcd_1.4-11 vctrs_0.6.2 png_0.1-8 spam_2.9-1
[56] cellranger_1.1.0 gtable_0.3.3 cachem_1.0.7 xfun_0.39 mime_0.12
[61] timeDate_4022.108 units_0.8-2 fields_14.1 ellipsis_0.3.2 xts_0.13.1
[66] rprojroot_2.0.3 KernSmooth_2.23-20 rpart_4.1.19 DBI_1.1.3 nnet_7.3-18
[71] tidyselect_1.2.0 processx_3.8.1 compiler_4.3.0 curl_5.0.0 htmlTable_2.4.1
[76] geojsonsf_2.0.3 xml2_1.3.4 stringfish_0.15.7 scales_1.2.1 DEoptimR_1.0-12
[81] classInt_0.4-9 lmtest_0.9-40 fracdiff_1.5-2 quadprog_1.5-8 callr_3.7.3
[86] mvnfast_0.2.8 digest_0.6.31 rmarkdown_2.21 htmltools_0.5.5 pkgconfig_2.0.3
[91] jpeg_0.1-10 base64enc_0.1-3 fastmap_1.1.1 quantmod_0.4.22 shiny_1.7.4
[96] visibly_0.2.9 jsonlite_1.8.4 Formula_1.2-5 dotCall64_1.0-2 patchwork_1.1.2
[101] munsell_0.5.0 leafsync_0.1.0 MASS_7.3-59 plyr_1.8.8 jqr_1.2.3
[106] pkgbuild_1.4.0 ggrepel_0.9.3 parallel_4.3.0 listenv_0.9.0 forcats_1.0.0
[111] splines_4.3.0 hms_1.1.3 ps_1.7.5 ranger_0.15.1 igraph_1.4.2
[116] base64url_1.4 ggsignif_0.6.4 reshape2_1.4.4 pkgload_1.3.2 crul_1.3
[121] XML_3.99-0.14 evaluate_0.20 mitools_2.4 RcppParallel_5.1.7 laeken_0.5.2
[126] BiocManager_1.30.20 tzdb_0.3.0 httpuv_1.6.9 RgoogleMaps_1.4.5.3 reshape_0.8.9
[131] lwgeom_0.2-11 xtable_1.8-4 e1071_1.7-13 rstatix_0.7.2 later_1.3.0
[136] class_7.3-21 tibble_3.2.1 memoise_2.0.1 cluster_2.1.4 timechange_0.2.0
[141] globals_0.16.2