Chapter 4 Updating models with stan
When we generate a model we often impose a lot of assumptions on the nature of causal relations. This includes “structure” regarding what relates to what but also the nature of those relations: how strong the effect of a given variable is and how it interacts with others, for example. The latter features are captured by parameters whose values, fortunately, can be informed by data.
The approach used by the CausalQueries package to updating parameter values given observed data uses stan and involves the following elements:
- Dirichlet priors over parameters, \(\lambda\) (which, in cases without confounding, correspond to nodal types)
- A mapping from parameters to event probabilities, \(w\)
- A likelihood function that assumes events are distributed according to a multinomial distribution given event probabilities.
We provide further details below.
4.1 Data for stan
We use a generic stan model that works for all binary causal models. Rather than writing a new stan model for each causal model we send stan details of each particular causal model as data inputs.
In particular we provide a set of matrices that let stan tailor itself to particular models: the parameter matrix (\(P\)) tells stan how many parameters there are, and how they map into causal types; an ambiguity matrix \(A\) tells stan how causal types map into data types; and an event matrix \(E\) relates data types to patterns of observed data (in cases where there are incomplete observations).
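The chain from parameters to event probabilities can be sketched in base R for the \(X \rightarrow Y\) model. This is an illustration only: the parameter values in lambda are made up, and the matrices are typed in by hand to match the layout described above rather than taken from the package.

```r
# Made-up parameter values: P(X = 0), P(X = 1), then the four Y nodal types
lambda <- c(X.0 = 0.4, X.1 = 0.6, Y.00 = 0.3, Y.10 = 0.2, Y.01 = 0.4, Y.11 = 0.1)

types      <- c("X0.Y00", "X1.Y00", "X0.Y10", "X1.Y10",
                "X0.Y01", "X1.Y01", "X0.Y11", "X1.Y11")
data_types <- c("X0Y0", "X1Y0", "X0Y1", "X1Y1")

# P: rows are parameters, columns are causal types
P <- matrix(0, 6, 8, dimnames = list(names(lambda), types))
P["X.0", c(1, 3, 5, 7)] <- 1
P["X.1", c(2, 4, 6, 8)] <- 1
P["Y.00", 1:2] <- 1
P["Y.10", 3:4] <- 1
P["Y.01", 5:6] <- 1
P["Y.11", 7:8] <- 1

# Probability of each causal type: product of the parameters flagged in its column
type_prob <- apply(P, 2, function(p) prod(lambda[p == 1]))

# A: rows are causal types, columns are data types (hand-coded for this model)
produces <- c("X0Y0", "X1Y0", "X0Y1", "X1Y0", "X0Y0", "X1Y1", "X0Y1", "X1Y1")
A <- outer(produces, data_types, "==") * 1
dimnames(A) <- list(types, data_types)

# Probability of each data type: sum over the causal types that produce it
w <- drop(t(A) %*% type_prob)

# E: rows are observable events (including Y-only events), columns are data types
E <- rbind(diag(4), c(1, 1, 0, 0), c(0, 0, 1, 1))
dimnames(E) <- list(c(data_types, "Y0", "Y1"), data_types)
w_full <- drop(E %*% w)
```

With these values the four complete data types have probabilities that sum to one, and the Y-only events simply pool the complete data types that share the same value of \(Y\).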
The internal function prep_stan_data prepares data for stan. You generally don’t need to use this manually, but we show here a sample of what it produces as input for stan.
We provide prep_stan_data with data in compact form (listing “data events”).
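library(CausalQueries)
library(knitr)  # for kable()
model <- make_model("X->Y")
# The NA in X marks a case in which only Y was sought
data <- data.frame(X = c(0, 1, 1, NA), Y = c(0, 1, 0, 1))
compact_data <- collapse_data(data, model)
kable(compact_data)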
event | strategy | count |
---|---|---|
X0Y0 | XY | 1 |
X1Y0 | XY | 1 |
X0Y1 | XY | 0 |
X1Y1 | XY | 1 |
Y0 | Y | 0 |
Y1 | Y | 1 |
Note that NAs are interpreted as data not having been sought. So in this case the interpretation is that there are two data strategies: data on \(Y\) and \(X\) was sought in three cases; data on \(Y\) only was sought in just one case.
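The way missing values define strategies can be mimicked in base R. The sketch below is not how collapse_data is implemented; label_row is a hypothetical helper written for this illustration, and unlike the table above it lists only events that were actually observed.

```r
data <- data.frame(X = c(0, 1, 1, NA), Y = c(0, 1, 0, 1))

# Strategy: which nodes were observed; event: the observed values themselves
label_row <- function(row, with_values) {
  seen <- !is.na(row)
  values <- if (with_values) row[seen] else ""
  paste0(names(row)[seen], values, collapse = "")
}

strategy <- apply(data, 1, label_row, with_values = FALSE)
event    <- apply(data, 1, label_row, with_values = TRUE)

# Counts of observed events within each strategy (zero-count events are dropped)
counts <- as.data.frame(table(event = event, strategy = strategy),
                        stringsAsFactors = FALSE)
counts <- counts[counts$Freq > 0, ]
```

Three rows fall under the XY strategy and one under the Y strategy, matching the interpretation given above.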
prep_stan_data then returns a list of objects that stan expects to receive. These include indicators of where each parameter set starts and ends (l_starts, l_ends) and where each data strategy starts and ends (strategy_starts, strategy_ends), as well as the matrices described above.
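The start and end indicators let flat vectors be cut into blocks. A minimal base R sketch, with values copied from the output listed in this section; slice is a hypothetical helper, not a package function.

```r
# Values copied from the prep_stan_data output for the X -> Y example
lambdas_prior   <- c(X.0 = 1, X.1 = 1, Y.00 = 1, Y.10 = 1, Y.01 = 1, Y.11 = 1)
l_starts        <- c(X = 1, Y = 3)
l_ends          <- c(X = 2, Y = 6)
Y               <- c(1, 1, 0, 1, 0, 1)   # event counts
strategy_starts <- c(1, 5)
strategy_ends   <- c(4, 6)

# Slice a flat vector into blocks given start and end positions
slice <- function(x, starts, ends) Map(function(s, e) x[s:e], starts, ends)

prior_sets   <- slice(lambdas_prior, l_starts, l_ends)
count_blocks <- slice(Y, strategy_starts, strategy_ends)

# Prior mean of each Dirichlet block is alpha / sum(alpha)
prior_means <- lapply(prior_sets, function(a) a / sum(a))

# Number of cases observed under each data strategy
cases_per_strategy <- vapply(count_blocks, sum, numeric(1))
```

Here the first block of counts covers the three cases with both \(X\) and \(Y\) observed and the second covers the single case with \(Y\) only.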
$parmap
X0Y0 X1Y0 X0Y1 X1Y1
X.0 1 0 1 0
X.1 0 1 0 1
Y.00 1 1 0 0
Y.10 0 1 1 0
Y.01 1 0 0 1
Y.11 0 0 1 1
attr(,"map")
X0Y0 X1Y0 X0Y1 X1Y1
X0Y0 1 0 0 0
X1Y0 0 1 0 0
X0Y1 0 0 1 0
X1Y1 0 0 0 1
$map
X0Y0 X1Y0 X0Y1 X1Y1
X0Y0 1 0 0 0
X1Y0 0 1 0 0
X0Y1 0 0 1 0
X1Y1 0 0 0 1
$n_paths
[1] 4
$n_params
[1] 6
$n_param_sets
[1] 2
$n_param_each
X Y
2 4
$l_starts
X Y
1 3
$l_ends
X Y
2 6
$node_starts
X Y
1 3
$node_ends
X Y
2 6
$n_nodes
[1] 2
$lambdas_prior
X.0 X.1 Y.00 Y.10 Y.01 Y.11
1 1 1 1 1 1
$n_data
[1] 4
$n_events
[1] 6
$n_strategies
[1] 2
$strategy_starts
[1] 1 5
$strategy_ends
[1] 4 6
$keep_transformed
[1] 1
$E
X0Y0 X1Y0 X0Y1 X1Y1
X0Y0 1 0 0 0
X1Y0 0 1 0 0
X0Y1 0 0 1 0
X1Y1 0 0 0 1
Y0 1 1 0 0
Y1 0 0 1 1
$Y
[1] 1 1 0 1 0 1
$P
Rows are parameters, grouped in parameter sets
Columns are causal types
Cell entries indicate whether a parameter probability is used
in the calculation of causal type probability
X0.Y00 X1.Y00 X0.Y10 X1.Y10 X0.Y01 X1.Y01 X0.Y11
X.0 1 0 1 0 1 0 1
X.1 0 1 0 1 0 1 0
Y.00 1 1 0 0 0 0 0
Y.10 0 0 1 1 0 0 0
Y.01 0 0 0 0 1 1 0
Y.11 0 0 0 0 0 0 1
X1.Y11
X.0 0
X.1 1
Y.00 0
Y.10 0
Y.01 0
Y.11 1
param_set (P)
$n_types
[1] 8
4.2 stan code
Below we show the stan code. This starts off with a block saying what input data is to be expected. Then there is a characterization of parameters and the transformed parameters. Then the likelihoods and priors are provided. stan takes it from there and generates a posterior distribution.
functions{
  row_vector col_sums(matrix X) {
    row_vector[cols(X)] s ;
    s = rep_row_vector(1, rows(X)) * X ;
    return s ;
  }
}
data {
  int<lower=1> n_params;
  int<lower=1> n_paths;
  int<lower=1> n_types;
  int<lower=1> n_param_sets;
  int<lower=1> n_nodes;
  array[n_param_sets] int<lower=1> n_param_each;
  int<lower=1> n_data;
  int<lower=1> n_events;
  int<lower=1> n_strategies;
  int<lower=0, upper=1> keep_transformed;
  vector<lower=0>[n_params] lambdas_prior;
  array[n_param_sets] int<lower=1> l_starts;
  array[n_param_sets] int<lower=1> l_ends;
  array[n_nodes] int<lower=1> node_starts;
  array[n_nodes] int<lower=1> node_ends;
  array[n_strategies] int<lower=1> strategy_starts;
  array[n_strategies] int<lower=1> strategy_ends;
  matrix[n_params, n_types] P;
  matrix[n_params, n_paths] parmap;
  matrix[n_paths, n_data] map;
  matrix<lower=0,upper=1>[n_events,n_data] E;
  array[n_events] int<lower=0> Y;
}
parameters {
  vector<lower=0>[n_params - n_param_sets] gamma;
}
transformed parameters {
  vector<lower=0, upper=1>[n_params] lambdas;
  vector<lower=1>[n_param_sets] sum_gammas;
  matrix[n_params, n_paths] parlam;
  matrix[n_nodes, n_paths] parlam2;
  vector<lower=0, upper=1>[n_paths] w_0;
  vector<lower=0, upper=1>[n_data] w;
  vector<lower=0, upper=1>[n_events] w_full;
  // Cases in which a parameter set has only one value need special handling
  // they have no gamma components and sum_gamma needs to be made manually
  for (i in 1:n_param_sets) {
    if (l_starts[i] >= l_ends[i]) {
      sum_gammas[i] = 1;
      // syntax here to return unity as a vector
      lambdas[l_starts[i]] = lambdas_prior[1]/lambdas_prior[1];
    }
    else if (l_starts[i] < l_ends[i]) {
      sum_gammas[i] =
        1 + sum(gamma[(l_starts[i] - (i-1)):(l_ends[i] - i)]);
      lambdas[l_starts[i]:l_ends[i]] =
        append_row(1, gamma[(l_starts[i] - (i-1)):(l_ends[i] - i)]) /
        sum_gammas[i];
    }
  }
  // Mapping from parameters to data types
  // (usual case): [n_par * n_data] * [n_par * n_data]
  parlam = rep_matrix(lambdas, n_paths) .* parmap;
  // Sum probability over nodes on each path
  for (i in 1:n_nodes) {
    parlam2[i,] = col_sums(parlam[(node_starts[i]):(node_ends[i]),]);
  }
  // then take product to get probability of data type on path
  for (i in 1:n_paths) {
    w_0[i] = prod(parlam2[,i]);
  }
  // last (if confounding): map to n_data columns instead of n_paths
  w = map'*w_0;
  // Extend/reduce to cover all observed data types
  w_full = E * w;
}
model {
  // Dirichlet distributions (earlier versions used gamma)
  for (i in 1:n_param_sets) {
    target += dirichlet_lpdf(lambdas[l_starts[i]:l_ends[i]] |
                             lambdas_prior[l_starts[i] :l_ends[i]]);
    target += -n_param_each[i] * log(sum_gammas[i]);
  }
  // Multinomials
  // Note with censoring event_probabilities might not sum to 1
  for (i in 1:n_strategies) {
    target += multinomial_lpmf(
      Y[strategy_starts[i]:strategy_ends[i]] |
      w_full[strategy_starts[i]:strategy_ends[i]]/
      sum(w_full[strategy_starts[i]:strategy_ends[i]]));
  }
}
// Option to export distribution of causal types
generated quantities{
  vector[n_types] prob_of_types;
  if (keep_transformed == 1){
    for (i in 1:n_types) {
      prob_of_types[i] = prod(P[, i].*lambdas + 1 - P[,i]);
    }
  }
  if (keep_transformed == 0){
    prob_of_types = rep_vector(1, n_types);
  }
}
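One possible reading of the line target += -n_param_each[i] * log(sum_gammas[i]) in the model block is as a change-of-variables adjustment for the map from gamma to lambdas. The source does not state this; the derivation below is our own sketch, writing \(K\) for n_param_each[i] and \(S\) for sum_gammas[i] within a single parameter set.

```latex
% One parameter set with K elements; gamma has K - 1 free positive values.
\[
S = 1 + \sum_{k=1}^{K-1} \gamma_k, \qquad
\lambda_1 = \frac{1}{S}, \qquad
\lambda_{k+1} = \frac{\gamma_k}{S} \quad (k = 1, \dots, K-1).
\]
% Derivatives of the free lambdas with respect to gamma
\[
\frac{\partial \lambda_{j+1}}{\partial \gamma_k}
  = \frac{\delta_{jk}}{S} - \frac{\gamma_j}{S^2}
  = \frac{1}{S}\left(\delta_{jk} - \lambda_{j+1}\right).
\]
% Matrix determinant lemma: det(I - u 1^T) = 1 - 1^T u
\[
\left|\det J\right|
  = S^{-(K-1)} \left(1 - \sum_{j=2}^{K} \lambda_j\right)
  = S^{-(K-1)} \lambda_1
  = S^{-K},
\qquad
\log\left|\det J\right| = -K \log S .
\]
```

Under this reading, the Dirichlet density is evaluated on the normalized lambdas and the extra term accounts for sampling taking place on the gamma scale.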
The stan model works as follows (technical!):
- We are interested in “sets” of parameters. For example in the \(X \rightarrow Y\) model we have two parameter sets (param_sets). The first is \(\lambda^X \in \{\lambda^X_0, \lambda^X_1\}\) whose elements give the probability that \(X\) is 0 or 1. These two probabilities sum to one. The second parameter set is \(\lambda^Y \in \{\lambda^Y_{00}, \lambda^Y_{10}, \lambda^Y_{01}, \lambda^Y_{11}\}\). These are also probabilities and their values sum to one. Note that in all we have 6 parameters but just 1 + 3 = 4 degrees of freedom.
- We would like to express priors over these parameters using multiple Dirichlet distributions (two in this case). In practice because we are dealing with multiple simplices of varying length, it is easier to express priors over gamma distributions with a unit scale parameter and shape parameter corresponding to the Dirichlet priors, \(\alpha\). We make use of the fact that if \(\lambda^X_0 \sim Gamma(\alpha^X_0,1)\) and \(\lambda^X_1 \sim Gamma(\alpha^X_1,1)\) independently, then \(\frac{1}{\lambda^X_0 +\lambda^X_1}(\lambda^X_0, \lambda^X_1) \sim Dirichlet(\alpha^X_0, \alpha^X_1)\). As the comment in the model block indicates, the code shown above applies Dirichlet densities to the normalized lambdas, whereas earlier versions used gamma densities. For a discussion of implementation of this approach in stan see https://discourse.mc-stan.org/t/ragged-array-of-simplexes/1382.
- For any candidate parameter vector \(\lambda\) we calculate the probability of causal types (prob_of_types) by taking, for each type \(i\), the product of the probabilities of all parameters (\(\lambda_j\)) that appear in column \(i\) of the parameter matrix \(P\). Thus the probability of a \((X_0,Y_{00})\) case is just \(\lambda^X_0 \times \lambda^Y_{00}\). The implementation in stan uses prob_of_types[i] \(= \prod_j \left(P_{j,i} \lambda_j + (1-P_{j,i})\right)\): this multiplies the probabilities of all parameters involved in the causal type (and substitutes 1s for parameters that are not). In the code shown above, \(P\) is provided as data to stan and \(1-P\) is computed in the generated quantities block.
- The probability of data types, w, is given by summing up the probabilities of all causal types that produce a given data type. For example, the probability of an \(X=0,Y=0\) case, \(w_{00}\), is \(\lambda^X_0\times \lambda^Y_{00} + \lambda^X_0\times \lambda^Y_{01}\). The ambiguity matrix \(A\) is provided to stan to indicate which probabilities need to be summed (in the code shown above, this summation runs through parmap and map).
- In the case of incomplete data we first identify the set of “data strategies”, where a collection of data strategies might be of the form “gather data on \(X\) and \(M\), but not \(Y\), for \(n_1\) cases and gather data on \(X\) and \(Y\), but not \(M\), for \(n_2\) cases.” The probability of an observed event, within a data strategy, is given by summing the probabilities of the types that could give rise to the incomplete data. For example, if \(X\) is observed but \(Y\) is not, then the probability of \(X=0, Y = \text{NA}\) is \(w_{00} +w_{01}\). The matrix \(E\) is passed to stan to figure out which event probabilities need to be combined for events with missing data.
- The probability of a dataset is then given by a multinomial distribution with these event probabilities (or, in the case of incomplete data, the product of multinomials, one for each data strategy). Justification for this approach relies on the likelihood principle and is discussed in Chapter 6.
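The steps above can be traced in base R for the \(X \rightarrow Y\) example. This is a sketch that mirrors the transformed parameters and model blocks under stated assumptions: the gamma values are a made-up draw, parmap and the counts Y are copied from the prep_stan_data output, and the prior term is left out so that only the likelihood is computed.

```r
# Hypothetical draw of the free gamma values (n_params - n_param_sets = 4)
gamma    <- c(1.5, 0.7, 1.2, 0.9)
l_starts <- c(1, 3)
l_ends   <- c(2, 6)

# Each simplex is (1, gamma) / (1 + sum(gamma)), as in transformed parameters
lambdas <- numeric(6)
for (i in seq_along(l_starts)) {
  g <- gamma[(l_starts[i] - (i - 1)):(l_ends[i] - i)]
  lambdas[l_starts[i]:l_ends[i]] <- c(1, g) / (1 + sum(g))
}

# parmap: rows are parameters, columns are data types X0Y0, X1Y0, X0Y1, X1Y1
parmap <- rbind(c(1, 0, 1, 0), c(0, 1, 0, 1),
                c(1, 1, 0, 0), c(0, 1, 1, 0), c(1, 0, 0, 1), c(0, 0, 1, 1))
parlam <- lambdas * parmap  # scale each row by its lambda

# Sum within node (X: rows 1-2, Y: rows 3-6), then multiply across nodes
w <- colSums(parlam[1:2, ]) * colSums(parlam[3:6, ])

# Event probabilities, adding the Y-only events Y0 and Y1
w_full <- c(w, w[1] + w[2], w[3] + w[4])

# Log likelihood: one multinomial per data strategy
Y <- c(1, 1, 0, 1, 0, 1)
strategy_starts <- c(1, 5)
strategy_ends   <- c(4, 6)
loglik <- 0
for (s in seq_along(strategy_starts)) {
  idx <- strategy_starts[s]:strategy_ends[s]
  loglik <- loglik +
    dmultinom(Y[idx], prob = w_full[idx] / sum(w_full[idx]), log = TRUE)
}
```

Because there is no confounding in this model, map is an identity matrix and w equals w_0, so that step is skipped here.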
4.3 Implementation
To update a CausalQueries model with data use update_model(model, data), where the data argument is a dataset containing some or all of the nodes in the model.
Other stan arguments can be passed to update_model, in particular:
- iter sets the number of iterations and ultimately the number of draws in the posterior
- chains sets the number of chains; doing multiple chains in parallel speeds things up
- lots of other options via ?rstan::stan
If you have multiple cores you can do parallel processing by setting an option such as options(mc.cores = parallel::detectCores()) before running CausalQueries.
The stan output from a simple model looks like this:
Inference for Stan model: simplexes.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25%
gamma[1] 3.38 0.24 9.81 0.27 0.89
gamma[2] 5.55 2.11 94.64 0.02 0.24
gamma[3] 8.13 2.85 81.99 0.04 0.48
gamma[4] 5.62 2.25 71.60 0.03 0.35
lambdas[1] 0.39 0.00 0.20 0.06 0.23
lambdas[2] 0.61 0.00 0.20 0.21 0.47
lambdas[3] 0.26 0.00 0.17 0.01 0.12
lambdas[4] 0.20 0.00 0.15 0.01 0.07
lambdas[5] 0.31 0.00 0.19 0.02 0.15
lambdas[6] 0.23 0.00 0.17 0.01 0.09
sum_gammas[1] 4.38 0.24 9.81 1.27 1.89
sum_gammas[2] 20.30 7.18 226.68 1.56 2.63
parlam[1,1] 0.39 0.00 0.20 0.06 0.23
parlam[1,2] 0.00 NaN 0.00 0.00 0.00
parlam[1,3] 0.39 0.00 0.20 0.06 0.23
parlam[1,4] 0.00 NaN 0.00 0.00 0.00
parlam[2,1] 0.00 NaN 0.00 0.00 0.00
parlam[2,2] 0.61 0.00 0.20 0.21 0.47
parlam[2,3] 0.00 NaN 0.00 0.00 0.00
parlam[2,4] 0.61 0.00 0.20 0.21 0.47
parlam[3,1] 0.26 0.00 0.17 0.01 0.12
parlam[3,2] 0.26 0.00 0.17 0.01 0.12
parlam[3,3] 0.00 NaN 0.00 0.00 0.00
parlam[3,4] 0.00 NaN 0.00 0.00 0.00
parlam[4,1] 0.00 NaN 0.00 0.00 0.00
parlam[4,2] 0.20 0.00 0.15 0.01 0.07
parlam[4,3] 0.20 0.00 0.15 0.01 0.07
parlam[4,4] 0.00 NaN 0.00 0.00 0.00
parlam[5,1] 0.31 0.00 0.19 0.02 0.15
parlam[5,2] 0.00 NaN 0.00 0.00 0.00
parlam[5,3] 0.00 NaN 0.00 0.00 0.00
parlam[5,4] 0.31 0.00 0.19 0.02 0.15
parlam[6,1] 0.00 NaN 0.00 0.00 0.00
parlam[6,2] 0.00 NaN 0.00 0.00 0.00
parlam[6,3] 0.23 0.00 0.17 0.01 0.09
parlam[6,4] 0.23 0.00 0.17 0.01 0.09
parlam2[1,1] 0.39 0.00 0.20 0.06 0.23
parlam2[1,2] 0.61 0.00 0.20 0.21 0.47
parlam2[1,3] 0.39 0.00 0.20 0.06 0.23
parlam2[1,4] 0.61 0.00 0.20 0.21 0.47
parlam2[2,1] 0.57 0.00 0.20 0.17 0.42
parlam2[2,2] 0.46 0.00 0.18 0.12 0.32
parlam2[2,3] 0.43 0.00 0.20 0.07 0.28
parlam2[2,4] 0.54 0.00 0.18 0.19 0.41
w_0[1] 0.22 0.00 0.14 0.03 0.11
w_0[2] 0.28 0.00 0.14 0.05 0.17
w_0[3] 0.17 0.00 0.13 0.01 0.07
w_0[4] 0.34 0.00 0.16 0.07 0.22
w[1] 0.22 0.00 0.14 0.03 0.11
w[2] 0.28 0.00 0.14 0.05 0.17
w[3] 0.17 0.00 0.13 0.01 0.07
w[4] 0.34 0.00 0.16 0.07 0.22
w_full[1] 0.22 0.00 0.14 0.03 0.11
w_full[2] 0.28 0.00 0.14 0.05 0.17
w_full[3] 0.17 0.00 0.13 0.01 0.07
w_full[4] 0.34 0.00 0.16 0.07 0.22
w_full[5] 0.49 0.00 0.15 0.21 0.40
w_full[6] 0.51 0.00 0.15 0.23 0.40
prob_of_types[1] 0.10 0.00 0.09 0.00 0.03
prob_of_types[2] 0.16 0.00 0.12 0.01 0.06
prob_of_types[3] 0.08 0.00 0.08 0.00 0.02
prob_of_types[4] 0.12 0.00 0.10 0.00 0.04
prob_of_types[5] 0.11 0.00 0.10 0.00 0.04
prob_of_types[6] 0.19 0.00 0.14 0.01 0.08
prob_of_types[7] 0.09 0.00 0.09 0.00 0.03
prob_of_types[8] 0.14 0.00 0.12 0.00 0.05
lp__ -10.41 0.05 1.70 -14.74 -11.23
50% 75% 97.5% n_eff Rhat
gamma[1] 1.73 3.35 15.40 1716 1
gamma[2] 0.70 1.98 24.44 2021 1
gamma[3] 1.19 2.99 36.14 827 1
gamma[4] 0.88 2.16 20.19 1015 1
lambdas[1] 0.37 0.53 0.79 2420 1
lambdas[2] 0.63 0.77 0.94 2420 1
lambdas[3] 0.24 0.38 0.64 1783 1
lambdas[4] 0.16 0.29 0.57 3125 1
lambdas[5] 0.29 0.44 0.72 3146 1
lambdas[6] 0.20 0.34 0.63 3925 1
sum_gammas[1] 2.73 4.35 16.40 1716 1
sum_gammas[2] 4.22 8.16 82.29 996 1
parlam[1,1] 0.37 0.53 0.79 2420 1
parlam[1,2] 0.00 0.00 0.00 NaN NaN
parlam[1,3] 0.37 0.53 0.79 2420 1
parlam[1,4] 0.00 0.00 0.00 NaN NaN
parlam[2,1] 0.00 0.00 0.00 NaN NaN
parlam[2,2] 0.63 0.77 0.94 2420 1
parlam[2,3] 0.00 0.00 0.00 NaN NaN
parlam[2,4] 0.63 0.77 0.94 2420 1
parlam[3,1] 0.24 0.38 0.64 1783 1
parlam[3,2] 0.24 0.38 0.64 1783 1
parlam[3,3] 0.00 0.00 0.00 NaN NaN
parlam[3,4] 0.00 0.00 0.00 NaN NaN
parlam[4,1] 0.00 0.00 0.00 NaN NaN
parlam[4,2] 0.16 0.29 0.57 3125 1
parlam[4,3] 0.16 0.29 0.57 3125 1
parlam[4,4] 0.00 0.00 0.00 NaN NaN
parlam[5,1] 0.29 0.44 0.72 3146 1
parlam[5,2] 0.00 0.00 0.00 NaN NaN
parlam[5,3] 0.00 0.00 0.00 NaN NaN
parlam[5,4] 0.29 0.44 0.72 3146 1
parlam[6,1] 0.00 0.00 0.00 NaN NaN
parlam[6,2] 0.00 0.00 0.00 NaN NaN
parlam[6,3] 0.20 0.34 0.63 3925 1
parlam[6,4] 0.20 0.34 0.63 3925 1
parlam2[1,1] 0.37 0.53 0.79 2420 1
parlam2[1,2] 0.63 0.77 0.94 2420 1
parlam2[1,3] 0.37 0.53 0.79 2420 1
parlam2[1,4] 0.63 0.77 0.94 2420 1
parlam2[2,1] 0.57 0.72 0.93 3091 1
parlam2[2,2] 0.45 0.59 0.81 3448 1
parlam2[2,3] 0.43 0.58 0.83 3091 1
parlam2[2,4] 0.55 0.68 0.88 3448 1
w_0[1] 0.19 0.30 0.54 2606 1
w_0[2] 0.26 0.37 0.60 3404 1
w_0[3] 0.14 0.24 0.48 2795 1
w_0[4] 0.32 0.45 0.69 2454 1
w[1] 0.19 0.30 0.54 2606 1
w[2] 0.26 0.37 0.60 3404 1
w[3] 0.14 0.24 0.48 2795 1
w[4] 0.32 0.45 0.69 2454 1
w_full[1] 0.19 0.30 0.54 2606 1
w_full[2] 0.26 0.37 0.60 3404 1
w_full[3] 0.14 0.24 0.48 2795 1
w_full[4] 0.32 0.45 0.69 2454 1
w_full[5] 0.50 0.60 0.77 2542 1
w_full[6] 0.50 0.60 0.79 2542 1
prob_of_types[1] 0.08 0.14 0.35 2015 1
prob_of_types[2] 0.13 0.23 0.46 1866 1
prob_of_types[3] 0.05 0.11 0.31 2543 1
prob_of_types[4] 0.09 0.17 0.37 3540 1
prob_of_types[5] 0.09 0.16 0.36 3012 1
prob_of_types[6] 0.16 0.28 0.53 2627 1
prob_of_types[7] 0.06 0.13 0.32 3336 1
prob_of_types[8] 0.11 0.21 0.45 3582 1
lp__ -10.00 -9.16 -8.32 1187 1
Samples were drawn using NUTS(diag_e) at Thu Oct 12 16:06:37 2023.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
Note the parameters include the gamma parameters plus transformed parameters, \(\lambda\), which are our parameters of interest and which CausalQueries then interprets as possible row probabilities for the \(P\) matrix.
4.4 Extensions
4.4.1 Arbitrary parameters
Although the package provides helpers to generate mappings from parameters to causal types via nodal types, it is possible to dispense with the nodal types altogether and provide a direct mapping from parameters to causal types.
For this you need to manually provide a P matrix and a corresponding parameters_df. As an example here is a model with complete confounding and parameters that correspond to causal types directly.
model <- make_model("X->Y")
# One parameter per causal type: P becomes an 8 x 8 identity matrix
model$P <- diag(8)
colnames(model$P) <- rownames(model$causal_types)
model$parameters_df <- data.frame(
  param_names = paste0("x", 1:8),
  param_set = 1,
  priors = 1,
  parameters = 1/8)
# Update fully confounded model on strongly correlated data
# (data are drawn from a fresh X -> Y model so that `model` keeps its custom P)
data <- make_data(make_model("X->Y"), n = 100, parameters = c(.5, .5, .1, .1, .7, .1))
fully_confounded <- update_model(model, data)
4.4.2 Non binary data
In principle the stan model could be extended to handle non binary data. Though a limitation of the current package there is no structural reason why nodes should be constrained to be dichotomous. The set of nodal and causal types however expands even more rapidly in the case of non binary data.
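The growth can be made concrete with a short count. A nodal type assigns one of the node's values to every combination of parent values, so a node with \(m\) possible values and \(k\) parents, each also taking \(m\) values, has \(m^{m^k}\) nodal types. The helper n_nodal_types below is a hypothetical function written for this illustration, not part of the package.

```r
# Number of nodal types for one node with `levels` possible values and
# `n_parents` parents, each also taking `levels` values
n_nodal_types <- function(levels, n_parents) levels^(levels^n_parents)

growth <- data.frame(
  n_parents = 0:2,
  binary    = n_nodal_types(2, 0:2),
  ternary   = n_nodal_types(3, 0:2)
)
```

With one parent a binary node has 4 nodal types, as for \(Y\) in the \(X \rightarrow Y\) model, while a three-valued node already has 27.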