easy conversion to draws from rstantools format
I am struggling to correctly convert the posterior draws returned by rstantools functions such as posterior_linpred() into a draws object. The problem is that the chain information gets dropped. Here is an example illustrating what I would like to have:
library(posterior)
#> Warning: package 'posterior' was built under R version 4.1.2
#> This is posterior version 1.2.2
#>
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#>
#> mad, sd, var
samp <- as_draws_matrix(example_draws())
## posterior_*() functions from rstantools return matrices like this:
rstantools_samp <- matrix(as.matrix(samp), niterations(samp)*nchains(samp), nvariables(samp))
colnames(rstantools_samp) <- variables(samp)
head(rstantools_samp)
#> mu tau theta[1] theta[2] theta[3] theta[4] theta[5]
#> [1,] 2.005831 2.767367 3.9617520 0.27123540 -0.7431706 2.104805 0.92348879
#> [2,] 1.458316 6.979976 0.1237101 -0.06901539 0.9518270 7.281225 -0.06195211
#> [3,] 5.814947 9.677075 21.2510465 14.93055775 1.8290945 1.381443 0.53106337
#> [4,] 6.849586 4.788366 14.6996540 8.58618604 2.6749150 4.393232 4.75807198
#> [5,] 1.805168 2.848165 5.9600546 1.15573721 3.1088628 1.994890 0.76885094
#> [6,] 3.841243 4.083357 5.7601096 9.90920447 -0.9956266 5.328625 5.88894271
#> theta[6] theta[7] theta[8]
#> [1,] 1.650237 3.320019 4.848542
#> [2,] 11.257502 9.621128 -8.640446
#> [3,] 7.155371 14.802013 -1.736363
#> [4,] 8.101547 9.491277 5.281551
#> [5,] 4.656270 1.208251 -4.540236
#> [6,] -1.701463 2.780403 7.075855
dim(rstantools_samp)
#> [1] 400 10
## rows are ordered by chain, so we have
all(rstantools_samp[1:100,1] == subset_draws(samp, variable="mu", chain=1))
#> [1] TRUE
all(rstantools_samp[101:200,1] == subset_draws(samp, variable="mu", chain=2))
#> [1] TRUE
## posterior should have a function that lets me create, from
## rstantools_samp, a draws object that knows the number of
## chains. This does not work (.nchains is silently ignored):
as_draws_matrix(rstantools_samp, .nchains=4)
#> # A draws_matrix: 400 iterations, 1 chains, and 10 variables
#> variable
#> draw mu tau theta[1] theta[2] theta[3] theta[4] theta[5] theta[6]
#> 1 2.01 2.8 3.96 0.271 -0.74 2.1 0.923 1.7
#> 2 1.46 7.0 0.12 -0.069 0.95 7.3 -0.062 11.3
#> 3 5.81 9.7 21.25 14.931 1.83 1.4 0.531 7.2
#> 4 6.85 4.8 14.70 8.586 2.67 4.4 4.758 8.1
#> 5 1.81 2.8 5.96 1.156 3.11 2.0 0.769 4.7
#> 6 3.84 4.1 5.76 9.909 -1.00 5.3 5.889 -1.7
#> 7 5.47 4.0 4.03 4.151 10.15 6.6 3.741 -2.2
#> 8 1.20 1.5 -0.28 1.846 0.47 4.3 1.467 3.3
#> 9 0.15 3.9 1.81 0.661 0.86 4.5 -1.025 1.1
#> 10 7.17 1.8 6.08 8.102 7.68 5.6 7.106 8.5
#> # ... with 390 more draws, and 2 more variables
Created on 2022-08-04 by the reprex package (v2.0.1)
What I can do is crudely set the nchains attribute to the number of chains. Instead, I think the call above should just work and return a draws object with 4 chains. This requires documenting the expected format of the input draws: rows sorted in column-major order, i.e. all iterations of chain 1 first, then all iterations of chain 2, and so on.
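For reference, the reshape such a .nchains argument would need can be written in base R. The sketch below assumes the stacked layout shown above (rows grouped by chain, iterations consecutive within each chain); `stacked_to_array()` is a hypothetical helper, not part of posterior. It fills an iteration x chain x variable array, the layout draws_array objects use, and only calls posterior if the package is installed.

```r
# Hypothetical helper (not part of posterior): reshape a stacked
# (niter * nchains) x nvar matrix, rows grouped by chain, into an
# iteration x chain x variable array.
stacked_to_array <- function(x, nchains) {
  stopifnot(is.matrix(x), nrow(x) %% nchains == 0)
  niter <- nrow(x) %/% nchains
  # R stores matrices column-major, so filling a 3-D array from x makes the
  # iteration index vary fastest, then the chain, then the variable.
  array(x, dim = c(niter, nchains, ncol(x)),
        dimnames = list(NULL, NULL, colnames(x)))
}

# Simulated stand-in for rstantools_samp: 4 chains x 100 iterations, 3 variables
set.seed(1)
x <- matrix(rnorm(400 * 3), nrow = 400,
            dimnames = list(NULL, c("mu", "tau", "theta[1]")))
arr <- stacked_to_array(x, nchains = 4)
dim(arr)                     # 100 4 3
all(arr[1, 2, ] == x[101, ]) # TRUE: iteration 1 of chain 2 is row 101

# Convert only if posterior is available
if (requireNamespace("posterior", quietly = TRUE)) {
  d <- posterior::as_draws_array(arr)
  posterior::nchains(d)
}
```

The design choice here is to let R's column-major storage do the work: no explicit loops are needed, which is also why the input must be documented as column-major.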
One option is to ingest posterior_...() function output via rvar(), as the internal format of rvar() is (by design) the same as the output of those functions. rvar() lets you set the number of chains:
library(posterior)
library(rstanarm)
mtcars_subset = mtcars[, c("hp", "cyl", "mpg")]
m = stan_glm(mpg ~ hp*cyl, data = mtcars_subset, chains = 4)
epred = rvar(posterior_epred(m), nchains = 4)
epred
#> rvar<1000,4>[32] mean ± sd:
#>
#> Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive
#> 20 ± 0.79 20 ± 0.79 26 ± 0.93 20 ± 0.79
#> Hornet Sportabout Valiant Duster 360 Merc 240D
#> 16 ± 0.88 21 ± 0.80 15 ± 1.01 28 ± 1.18
#> Merc 230 Merc 280 Merc 280C Merc 450SE
#> 26 ± 0.94 20 ± 0.81 20 ± 0.81 15 ± 0.85
#> Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental
#> 15 ± 0.85 15 ± 0.85 15 ± 0.79 15 ± 0.81
#> Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla
#> 15 ± 0.89 28 ± 1.10 29 ± 1.41 28 ± 1.12
#> Toyota Corona Dodge Challenger AMC Javelin Camaro Z28
#> 26 ± 0.97 16 ± 1.09 16 ± 1.09 15 ± 1.01
#> Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa
#> 16 ± 0.88 28 ± 1.10 26 ± 0.91 24 ± 1.25
#> Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E
#> 14 ± 1.21 18 ± 1.44 13 ± 2.11 25 ± 1.17
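The same route should also work on a plain stacked matrix such as rstantools_samp, without fitting a model. A minimal sketch with simulated draws (the matrix `stacked` and its dimensions are made up for illustration), assuming the rows are grouped by chain as in the question:

```r
if (requireNamespace("posterior", quietly = TRUE)) {
  # Simulated stacked draws: 400 rows (4 chains x 100 iterations), 2 variables
  set.seed(2)
  stacked <- matrix(rnorm(400 * 2), nrow = 400,
                    dimnames = list(NULL, c("a", "b")))

  # rvar() treats the first dimension as draws; nchains splits them into
  # consecutive blocks of 100 draws per chain
  r <- posterior::rvar(stacked, nchains = 4)
  print(posterior::nchains(r))

  # Converting the rvar keeps the chain information
  dm <- posterior::as_draws_matrix(r)
  print(posterior::nchains(dm))
  print(posterior::ndraws(dm))
}
```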
This can be especially useful for the posterior_...() functions, since you can put the resulting rvars in a data frame alongside the data used to make the predictions:
cbind(mtcars_subset, epred = epred)
#> hp cyl mpg epred
#> Mazda RX4 110 6 21.0 20.48962 ± 0.7908150
#> Mazda RX4 Wag 110 6 21.0 20.48962 ± 0.7908150
#> Datsun 710 93 4 22.8 25.80881 ± 0.9256434
#> Hornet 4 Drive 110 6 21.4 20.48962 ± 0.7908150
#> Hornet Sportabout 175 8 18.7 15.51820 ± 0.8760158
#> Valiant 105 6 18.1 20.70694 ± 0.8036316
#> Duster 360 245 8 14.3 14.55358 ± 1.0112446
#> Merc 240D 62 4 24.4 28.07635 ± 1.1825719
#> Merc 230 95 4 22.8 25.66252 ± 0.9437801
#> Merc 280 123 6 19.2 19.92459 ± 0.8124373
#> Merc 280C 123 6 17.8 19.92459 ± 0.8124373
#> Merc 450SE 180 8 16.4 15.44930 ± 0.8459401
#> Merc 450SL 180 8 17.3 15.44930 ± 0.8459401
#> Merc 450SLC 180 8 15.2 15.44930 ± 0.8459401
#> Cadillac Fleetwood 205 8 10.4 15.10479 ± 0.7862891
#> Lincoln Continental 215 8 10.4 14.96699 ± 0.8091412
#> Chrysler Imperial 230 8 14.7 14.76028 ± 0.8889254
#> Fiat 128 66 4 32.4 27.78376 ± 1.1026873
#> Honda Civic 52 4 30.4 28.80781 ± 1.4145373
#> Toyota Corolla 65 4 33.9 27.85691 ± 1.1217972
#> Toyota Corona 97 4 21.5 25.51622 ± 0.9659045
#> Dodge Challenger 150 8 15.5 15.86271 ± 1.0899219
#> AMC Javelin 150 8 15.2 15.86271 ± 1.0899219
#> Camaro Z28 245 8 13.3 14.55358 ± 1.0112446
#> Pontiac Firebird 175 8 19.2 15.51820 ± 0.8760158
#> Fiat X1-9 66 4 27.3 27.78376 ± 1.1026873
#> Porsche 914-2 91 4 26.0 25.95510 ± 0.9117325
#> Lotus Europa 113 4 30.4 24.34588 ± 1.2535618
#> Ford Pantera L 264 8 15.8 14.29175 ± 1.2067307
#> Ferrari Dino 175 6 19.7 17.66450 ± 1.4376777
#> Maserati Bora 335 8 15.0 13.31335 ± 2.1102372
#> Volvo 142E 109 4 21.4 24.63846 ± 1.1669329
And if you do want it as a draws_matrix, you can use as_draws_matrix():
as_draws_matrix(epred)
#> # A draws_matrix: 1000 iterations, 4 chains, and 32 variables
#> variable
#> draw x[Mazda RX4] x[Mazda RX4 Wag] x[Datsun 710] x[Hornet 4 Drive]
#> 1 21 21 26 21
#> 2 21 21 25 21
#> 3 18 18 25 18
#> 4 21 21 24 21
#> 5 21 21 25 21
#> 6 21 21 24 21
#> 7 23 23 28 23
#> 8 21 21 26 21
#> 9 21 21 28 21
#> 10 22 22 25 22
#> variable
#> draw x[Hornet Sportabout] x[Valiant] x[Duster 360] x[Merc 240D]
#> 1 15 21 13 28
#> 2 15 21 13 28
#> 3 14 19 15 27
#> 4 17 21 17 26
#> 5 17 21 16 27
#> 6 17 21 17 26
#> 7 15 23 12 30
#> 8 16 21 16 28
#> 9 15 22 13 30
#> 10 18 22 17 27
#> # ... with 3990 more draws, and 24 more variables
That said, since draws_matrix() has a .nchains argument, it does seem like as_draws_matrix() should have one too.
Hi!
Indeed, as_draws_matrix(rvar(rstantools_samp, nchains=4)) gives me what I want for the example I posted. Maybe all of the as_draws_*() functions should have a .nchains argument? as_draws_matrix() certainly needs one, and the documentation should describe the format posterior expects the input to be in (column-major).
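To make "column-major" concrete: in the stacked layout, each row index corresponds to exactly one (iteration, chain) pair. A small base-R illustration, assuming 100 iterations and 4 chains as in the example; `row_of()` is a hypothetical helper, not a posterior function:

```r
# Assumed layout (matching rstantools_samp): all iterations of chain 1 first,
# then chain 2, and so on. Label every row with its chain and iteration.
niter <- 100
nchains <- 4
chain_id     <- rep(seq_len(nchains), each = niter)
iteration_id <- rep(seq_len(niter), times = nchains)

# Hypothetical helper: row of the stacked matrix holding iteration i of chain c
row_of <- function(i, c, niter) (c - 1) * niter + i

row_of(1, 2, niter)          # 101, the first draw of chain 2
range(which(chain_id == 2))  # 101 200, the rows checked for chain 2 above
```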
Thanks!