-
Notifications
You must be signed in to change notification settings - Fork 0
/
brooks_gelman_stat_multivariate.pro
111 lines (98 loc) · 3.13 KB
/
brooks_gelman_stat_multivariate.pro
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
function brooks_gelman_stat_multivariate_restore, filename
;+
; NAME:
;   BROOKS_GELMAN_STAT_MULTIVARIATE_RESTORE
; PURPOSE:
;   Private helper. RESTOREs an IDL SAVE file inside its own scope and
;   returns the variable named 'chain' that the file defines.
; EXPLANATION:
;   RESTORE recreates variables by name in the scope that calls it. Doing
;   the restore here, rather than in the main routine, keeps it from
;   overwriting the main routine's by-reference CHAIN argument (and with it
;   the caller's filename variable).
; INPUTS:
;   filename - path of the IDL SAVE file to restore.
; OUTPUTS:
;   The restored 'chain' variable, or !null if the file defines no
;   variable of that name.
;-
  compile_opt idl2, hidden
  on_error, 2

  restore, filename
  if n_elements(chain) eq 0 then return, !null
  return, chain
end


function brooks_gelman_stat_multivariate, chain
;+
; NAME:
;   BROOKS_GELMAN_STAT_MULTIVARIATE
; PURPOSE:
;   To compute the multivariate Brooks-Gelman statistic (r_hat) [1] of
;   multiple multivariate Markov chains.
; EXPLANATION:
;   Uses the within-chain covariance W and the between-chain covariance
;   B/N of the chains to assess convergence, following the multivariate
;   approach of Brooks and Gelman (1998):
;
;       r_hat = (N-1)/N + (M+1)/M * lam1
;
;   where lam1 is the largest eigenvalue of W^-1 (B/N). lam1 is obtained
;   from the generalized symmetric-definite eigenproblem
;   (B/N) z = lam W z, which has the same eigenvalues as W^-1 (B/N).
;
; CALLING SEQUENCE:
;   r_hat = brooks_gelman_stat_multivariate(chain)
;
; INPUTS:
;   chain - a P x N x M array of M Markov chains of N elements and P
;           parameters (integer, long, float or double; N >= 10, M >= 2).
;           Optionally, chain can be a filename pointing to an IDL SAVE
;           file containing an MCMC chain called 'chain' as described
;           above. The argument itself is never modified.
;
; OUTPUTS:
;   r_hat - the Brooks-Gelman statistic of the chains; sqrt(r_hat) should
;           be less than or equal to 1.2 if convergence of the chains was
;           reached. r_hat approaches (N-1)/N, i.e. ~1 for long chains,
;           as the between-chain covariance becomes negligible compared
;           to the within-chain covariance.
;           !null is returned (after printing a message) if the input is
;           unusable or W is not positive definite.
;
; EXAMPLE USAGE:
;   IDL> chain = randomn(seed,10,1000,3)
;   IDL> r_hat = brooks_gelman_stat_multivariate(chain)
;   IDL> print,r_hat
;
;   IDL> chain = randomn(seed,10,1000,3)
;   IDL> save, chain, filename='path-to-chain.sav'
;   IDL> ; some time later...
;   IDL> r_hat = brooks_gelman_stat_multivariate('path-to-chain.sav')
;   IDL> print,r_hat
;
; REFERENCE:
;   [1]: http://www.stat.columbia.edu/~gelman/research/published/brooksgelman2.pdf
;
; REVISION HISTORY:
;   Written by K. Doore, 8/26/2021
;   Added reference and option to read from file, E.B. Monson, 8/26/2021
;   Largest eigenvalue now taken from the generalized symmetric
;     eigenproblem instead of the symmetrized W^-1 (B/N); the input
;     argument is no longer overwritten when a filename is given.
;-
  compile_opt idl2
  on_error, 2

  ; Check arguments
  if (n_params() ne 1) then begin
    print,'Syntax - brooks_gelman_stat_multivariate(chain)'
    return,!null
  endif

  ; If `chain` is a filename, load the array from the SAVE file. The restore
  ; runs in a helper scope and on a local copy of the filename so that the
  ; caller's variable is left untouched.
  if isa(chain, /string) then begin
    print, 'Restoring ', strtrim(chain, 2)
    savefile = chain
    samples = brooks_gelman_stat_multivariate_restore(savefile)
  endif else if n_elements(chain) gt 0 then begin
    samples = chain
  endif

  ; Check for allowable input chains (IDL types 2-5: int, long, float, double).
  ; An undefined or unrestorable chain has type code 0 and is rejected here.
  type_code = size(samples, /type)
  if (type_code lt 2) or (type_code gt 5) then begin
    print,'Chain is incorrect data type'
    return,!null
  endif

  if size(samples, /n_dimensions) ne 3 then begin
    print,'Chain must be three-dimensional'
    return,!null
  endif
  dims    = size(samples, /dimensions)
  n_par   = dims[0]
  n_samp  = dims[1]
  n_chain = dims[2]

  ; Check chains are long enough
  if (n_samp lt 10) then begin
    print,'Chains not long enough to compute r_hat'
    return,!null
  endif

  ; The between-chain covariance is normalised by (M-1)
  if (n_chain lt 2) then begin
    print,'At least two chains are required to compute r_hat'
    return,!null
  endif

  N = double(n_samp)
  M = double(n_chain)

  ; Work in double precision throughout
  samples = double(temporary(samples))

  psi_j_dot   = mean(samples, dimension=2)      ; P x M : per-chain means
  psi_dot_dot = mean(psi_j_dot, dimension=2)    ; P     : grand mean

  w_sum  = dblarr(n_par, n_par)
  bn_sum = dblarr(n_par, n_par)
  for j = 0, n_chain-1 do begin
    ; Deviations of every sample of chain j from that chain's mean (P x N).
    ; dev # transpose(dev) is the sum over samples of the P x P outer
    ; products of the deviation vectors.
    dev = samples[*,*,j] - rebin(psi_j_dot[*,j], n_par, n_samp)
    w_sum += matrix_multiply(dev, dev, /btranspose)

    ; Outer product of the chain-mean deviation from the grand mean
    mean_dev = psi_j_dot[*,j] - psi_dot_dot
    bn_sum += mean_dev ## mean_dev
  endfor

  W  = w_sum  / (M*(N - 1))
  Bn = bn_sum / (M - 1)

  if n_par eq 1 then begin
    ; A single parameter reduces the eigenproblem to a scalar ratio
    lam1 = Bn[0] / W[0]
  endif else begin
    ; Eigenvalues of W^-1 Bn via the generalized symmetric-definite problem
    ; Bn z = lam W z. Symmetrizing W^-1 Bn itself would change its spectrum.
    eigenvalues = la_eigenql(Bn, W, /double, status=status)
    if status ne 0 then begin
      print,'Within-chain covariance is not positive definite; cannot compute r_hat'
      return,!null
    endif
    lam1 = max(eigenvalues)
  endelse

  r_hat = (N - 1)/N + (M + 1)/M * lam1

  return, r_hat
end