-
-
Notifications
You must be signed in to change notification settings - Fork 478
/
3.7_Prediction.R
60 lines (48 loc) · 1.91 KB
/
3.7_Prediction.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
library(rstan)
library(ggplot2)
### Data
# Presumably defines N, kid_score, mom_hs, mom_iq and the new-child
# predictors mom_hs_new / mom_iq_new used below -- see kidiq.data.R.
source("ARM/Ch.3/kidiq.data.R", echo = TRUE)
### Model: kid_score ~ mom_hs + mom_iq
# `data` is given as a character vector of object names; rstan looks these
# objects up in the calling environment.
data.list <- c("N", "kid_score", "mom_hs", "mom_iq",
               "mom_hs_new", "mom_iq_new")
kidiq_prediction.sf <- stan(file = "ARM/Ch.3/kidiq_prediction.stan",
                            data = data.list,
                            iter = 500, chains = 4)
# Central 95% posterior interval (2.5% and 97.5% quantiles) for the predicted
# score. The argument is `probs`; the previous `prob` only worked through
# R's partial argument matching.
print(kidiq_prediction.sf, pars = c("kid_score_pred"),
      probs = c(0.025, 0.975))
pairs(kidiq_prediction.sf)
### Data
# Training sample for external validation (children born before 1987);
# presumably defines N, ppvt, hs and afqt -- see kids_before1987.data.R.
source("ARM/Ch.3/kids_before1987.data.R", echo = TRUE)
### Model: ppvt ~ hs + afqt
data.list <- c("N", "ppvt", "hs", "afqt")
kidiq_pre1987.sf <- stan(file = "ARM/Ch.3/kidiq_validation.stan",
                         data = data.list,
                         iter = 500, chains = 4)
print(kidiq_pre1987.sf, pars = c("beta", "sigma", "lp__"))
# Pairs plot of the fit just made. The previous code passed
# kidiq_prediction.sf here, re-plotting the earlier model by mistake.
pairs(kidiq_pre1987.sf)
### External validation
## Data
# Held-out sample (after 1987); presumably defines ppvt_ev, hs_ev and
# afqt_ev, which are used below.
source("ARM/Ch.3/kids_after1987.data.R", echo = TRUE)
## Predicted scores
# Point predictions use the posterior mean of each coefficient fitted on the
# pre-1987 data. The coefficient order (intercept, hs, afqt) is assumed to
# match the beta vector in kidiq_validation.stan -- TODO confirm.
beta.post <- extract(kidiq_pre1987.sf, pars = "beta")[["beta"]]
beta.mean <- colMeans(beta.post)
cscores.new <- beta.mean[1] +
  beta.mean[2] * hs_ev +
  beta.mean[3] * afqt_ev
# Out-of-sample prediction errors and their standard deviation.
resid <- ppvt_ev - cscores.new
resid.sd <- sd(resid)
## Figure 3.13
# left: actual vs. predicted scores, with the y = x line marking perfect
# prediction. Both axes share the range [20, 140]; ggplot2 drops (with a
# warning) any point outside it.
p1 <- ggplot(data = data.frame(cscores.new, ppvt_ev),
             mapping = aes(x = cscores.new, y = ppvt_ev)) +
  geom_point() +
  geom_abline(slope = 1, intercept = 0) +
  scale_x_continuous(name = "Predicted score",
                     limits = c(20, 140),
                     breaks = seq(from = 20, to = 140, by = 20)) +
  scale_y_continuous(name = "Actual score",
                     limits = c(20, 140),
                     breaks = seq(from = 20, to = 140, by = 20)) +
  theme_bw()
print(p1)
# right: prediction error against predicted score. The solid line marks zero
# error; the dashed lines mark plus/minus one residual standard deviation.
# dev.new() opens a fresh graphics device so this plot does not replace p1.
dev.new()
p2 <- ggplot(data = data.frame(cscores.new, resid),
             mapping = aes(x = cscores.new, y = resid)) +
  geom_point() +
  geom_hline(yintercept = 0) +
  geom_hline(yintercept = resid.sd * c(-1, 1), linetype = "dashed") +
  scale_x_continuous(name = "Predicted score",
                     breaks = seq(from = 70, to = 100, by = 10)) +
  scale_y_continuous(name = "Prediction error",
                     breaks = seq(from = -60, to = 40, by = 20)) +
  theme_bw()
print(p2)