Chapter 4 Updating models with stan

When we generate a model we often impose a lot of assumptions on the nature of causal relations. These include “structure” regarding what relates to what, but also the nature of those relations: how strong the effect of a given variable is and how it interacts with others, for example. The latter features are captured by parameters whose values, fortunately, can be learned from data.

The CausalQueries package uses stan to update parameter values given observed data. The approach involves the following elements:

  • Dirichlet priors over parameters, \(\lambda\) (which, in cases without confounding, correspond to nodal types)
  • A mapping from parameters to event probabilities, \(w\)
  • A likelihood function that assumes events are distributed according to a multinomial distribution given event probabilities.
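Put together, these elements define a simple generative model. The summary below is a sketch in our own notation: \(k\) indexes parameter sets, \(s\) indexes data strategies, \(Y_s\) is the vector of event counts for strategy \(s\), \(n_s\) its total, \(f\) the mapping from parameters to event probabilities, and \(w^{(s)}\) the event probabilities for strategy \(s\), normalized within the strategy as in the stan code below.

```latex
\begin{aligned}
\lambda^{(k)} &\sim \mathrm{Dirichlet}\left(\alpha^{(k)}\right), && k = 1, \dots, n_{\text{param\_sets}} \\
w &= f(\lambda) \\
Y_s &\sim \mathrm{Multinomial}\left(n_s,\ \frac{w^{(s)}}{\sum_e w^{(s)}_e}\right), && s = 1, \dots, n_{\text{strategies}}
\end{aligned}
```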

We provide further details below.

4.1 Data for stan

We use a generic stan model that works for all binary causal models. Rather than writing a new stan model for each causal model, we send stan the details of each particular causal model as data inputs.

In particular, we provide a set of matrices that let stan tailor itself to particular models: the parameter matrix (\(P\)) tells stan how many parameters there are and how they map into causal types; an ambiguity matrix \(A\) tells stan how causal types map into data types; and an event matrix \(E\) relates data types to patterns of observed data (in cases where there are incomplete observations). In the current implementation, shown below, the mapping to data types is passed directly from parameters, via parmap and map, rather than through \(A\).

The internal function prep_stan_data prepares data for stan. You generally don’t need to use this manually, but we show here a sample of what it produces as input for stan.

We provide prep_stan_data with data in compact form (listing “data events”).

library(CausalQueries)
library(knitr)

model <- make_model("X->Y")

data  <- data.frame(X = c(0, 1, 1, NA), Y = c(0, 1, 0, 1)) 

compact_data <-  collapse_data(data, model) 

kable(compact_data)
event strategy count
X0Y0 XY 1
X1Y0 XY 1
X0Y1 XY 0
X1Y1 XY 1
Y0 Y 0
Y1 Y 1

Note that NAs are interpreted as data not having been sought. So in this case the interpretation is that there are two data strategies: data on \(Y\) and \(X\) was sought in three cases; data on \(Y\) only was sought in just one case.
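The compact form can be reproduced by hand for this small example. The sketch below is a base-R approximation of what collapse_data does here, not its actual implementation: it treats each pattern of non-missing columns as a data strategy, enumerates every binary event within a strategy (which is where the zero counts come from), and tallies the observed rows.

```r
data <- data.frame(X = c(0, 1, 1, NA), Y = c(0, 1, 0, 1))

# Event label for each row: observed node names pasted with their values
event <- apply(data, 1, function(row) {
  obs <- !is.na(row)
  paste0(names(data)[obs], row[obs], collapse = "")
})

# Each distinct pattern of observed (non-NA) nodes is one data strategy
observed   <- !is.na(data)
strategies <- unique(as.data.frame(observed))

compact <- do.call(rbind, lapply(seq_len(nrow(strategies)), function(i) {
  nodes <- names(data)[unlist(strategies[i, ])]
  # All binary value combinations of the observed nodes (first node varies fastest)
  grid <- expand.grid(rep(list(0:1), length(nodes)))
  ev <- apply(grid, 1, function(v) paste0(nodes, v, collapse = ""))
  data.frame(event = ev,
             strategy = paste(nodes, collapse = ""),
             count = vapply(ev, function(e) sum(event == e), integer(1),
                            USE.NAMES = FALSE),
             row.names = NULL)
}))
compact
```

For this dataset the result has the same rows and counts as the kable output above; for larger models the package's own ordering of events and strategies may differ from this sketch.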

prep_stan_data then returns a list of objects that stan expects to receive. These include indicators marking where each parameter set starts and ends (l_starts, l_ends) and where each data strategy starts and ends (strategy_starts, strategy_ends), as well as the matrices described above.
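These start and end indicators are just cumulative counts over contiguous blocks. A minimal base-R sketch, assuming (as in the output below) that the events of each strategy and the parameters of each set are stored contiguously:

```r
# Strategy label of each event row in compact_data, in order
strategy <- c("XY", "XY", "XY", "XY", "Y", "Y")
runs <- rle(strategy)                      # lengths of consecutive runs
strategy_ends   <- cumsum(runs$lengths)    # 4 6
strategy_starts <- strategy_ends - runs$lengths + 1  # 1 5

# Number of parameters in each parameter set
n_param_each <- c(X = 2, Y = 4)
l_ends   <- cumsum(n_param_each)           # X = 2, Y = 6
l_starts <- l_ends - n_param_each + 1      # X = 1, Y = 3
```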

CausalQueries:::prep_stan_data(model, compact_data)
$parmap
     X0Y0 X1Y0 X0Y1 X1Y1
X.0     1    0    1    0
X.1     0    1    0    1
Y.00    1    1    0    0
Y.10    0    1    1    0
Y.01    1    0    0    1
Y.11    0    0    1    1
attr(,"map")
     X0Y0 X1Y0 X0Y1 X1Y1
X0Y0    1    0    0    0
X1Y0    0    1    0    0
X0Y1    0    0    1    0
X1Y1    0    0    0    1

$map
     X0Y0 X1Y0 X0Y1 X1Y1
X0Y0    1    0    0    0
X1Y0    0    1    0    0
X0Y1    0    0    1    0
X1Y1    0    0    0    1

$n_paths
[1] 4

$n_params
[1] 6

$n_param_sets
[1] 2

$n_param_each
X Y 
2 4 

$l_starts
X Y 
1 3 

$l_ends
X Y 
2 6 

$node_starts
X Y 
1 3 

$node_ends
X Y 
2 6 

$n_nodes
[1] 2

$lambdas_prior
 X.0  X.1 Y.00 Y.10 Y.01 Y.11 
   1    1    1    1    1    1 

$n_data
[1] 4

$n_events
[1] 6

$n_strategies
[1] 2

$strategy_starts
[1] 1 5

$strategy_ends
[1] 4 6

$keep_transformed
[1] 1

$E
     X0Y0 X1Y0 X0Y1 X1Y1
X0Y0    1    0    0    0
X1Y0    0    1    0    0
X0Y1    0    0    1    0
X1Y1    0    0    0    1
Y0      1    1    0    0
Y1      0    0    1    1

$Y
[1] 1 1 0 1 0 1

$P

Rows are parameters, grouped in parameter sets

Columns are causal types

Cell entries indicate whether a parameter probability is used
in the calculation of causal type probability

     X0.Y00 X1.Y00 X0.Y10 X1.Y10 X0.Y01 X1.Y01 X0.Y11 X1.Y11
X.0       1      0      1      0      1      0      1      0
X.1       0      1      0      1      0      1      0      1
Y.00      1      1      0      0      0      0      0      0
Y.10      0      0      1      1      0      0      0      0
Y.01      0      0      0      0      1      1      0      0
Y.11      0      0      0      0      0      0      1      1

 
 param_set  (P)
 
$n_types
[1] 8

4.2 stan code

Below we show the stan code. After a short functions block defining a helper, col_sums, the code declares the input data it expects, then the parameters and the transformed parameters. The model block then provides the priors and likelihoods, and an optional generated quantities block computes the probabilities of causal types. stan takes it from there and generates a posterior distribution.

functions{
  row_vector col_sums(matrix X) {
    row_vector[cols(X)] s ;
    s = rep_row_vector(1, rows(X)) * X ;
    return s ;
  }
}
data {
int<lower=1> n_params;
int<lower=1> n_paths;
int<lower=1> n_types;
int<lower=1> n_param_sets;
int<lower=1> n_nodes;
array[n_param_sets] int<lower=1> n_param_each;
int<lower=1> n_data;
int<lower=1> n_events;
int<lower=1> n_strategies;
int<lower=0, upper=1> keep_transformed;
vector<lower=0>[n_params] lambdas_prior;
array[n_param_sets] int<lower=1> l_starts;
array[n_param_sets] int<lower=1> l_ends;
array[n_nodes] int<lower=1> node_starts;
array[n_nodes] int<lower=1> node_ends;
array[n_strategies] int<lower=1> strategy_starts;
array[n_strategies] int<lower=1> strategy_ends;
matrix[n_params, n_types] P;
matrix[n_params, n_paths] parmap;
matrix[n_paths, n_data] map;
matrix<lower=0,upper=1>[n_events,n_data] E;
array[n_events] int<lower=0> Y;
}
parameters {
vector<lower=0>[n_params - n_param_sets] gamma;
}
transformed parameters {
vector<lower=0, upper=1>[n_params] lambdas;
vector<lower=1>[n_param_sets] sum_gammas;
matrix[n_params, n_paths] parlam;
matrix[n_nodes, n_paths] parlam2;
vector<lower=0, upper=1>[n_paths] w_0;
vector<lower=0, upper=1>[n_data] w;
vector<lower=0, upper=1>[n_events] w_full;
// Cases in which a parameter set has only one value need special handling
// they have no gamma components and sum_gamma needs to be made manually
for (i in 1:n_param_sets) {
  if (l_starts[i] >= l_ends[i]) {
    sum_gammas[i] = 1;
    // syntax here to return unity as a vector
    lambdas[l_starts[i]] = lambdas_prior[1]/lambdas_prior[1];
    }
  else if (l_starts[i] < l_ends[i]) {
    sum_gammas[i] =
    1 + sum(gamma[(l_starts[i] - (i-1)):(l_ends[i] - i)]);
    lambdas[l_starts[i]:l_ends[i]] =
    append_row(1, gamma[(l_starts[i] - (i-1)):(l_ends[i] - i)]) /
      sum_gammas[i];
    }
  }
// Mapping from parameters to data types
// (usual case): elementwise product of two [n_params x n_paths] matrices
parlam  = rep_matrix(lambdas, n_paths) .* parmap;
// Sum probability over nodes on each path
for (i in 1:n_nodes) {
 parlam2[i,] = col_sums(parlam[(node_starts[i]):(node_ends[i]),]);
 }
// then take product  to get probability of data type on path
for (i in 1:n_paths) {
  w_0[i] = prod(parlam2[,i]);
 }
 // last (if confounding): map to n_data columns instead of n_paths
 w = map'*w_0;
  // Extend/reduce to cover all observed data types
 w_full = E * w;
}
model {
// Dirichlet distributions (earlier versions used gamma)
for (i in 1:n_param_sets) {
  target += dirichlet_lpdf(lambdas[l_starts[i]:l_ends[i]]  |
    lambdas_prior[l_starts[i] :l_ends[i]]);
  target += -n_param_each[i] * log(sum_gammas[i]);
 }
// Multinomials
// Note with censoring event_probabilities might not sum to 1
for (i in 1:n_strategies) {
  target += multinomial_lpmf(
  Y[strategy_starts[i]:strategy_ends[i]] |
    w_full[strategy_starts[i]:strategy_ends[i]]/
     sum(w_full[strategy_starts[i]:strategy_ends[i]]));
 }
}
// Option to export distribution of causal types
generated quantities{
vector[n_types] prob_of_types;
if (keep_transformed == 1){
for (i in 1:n_types) {
   prob_of_types[i] = prod(P[, i].*lambdas + 1 - P[,i]);
}}
 if (keep_transformed == 0){
    prob_of_types = rep_vector(1, n_types);
 }
}
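The line target += -n_param_each[i] * log(sum_gammas[i]) in the model block is a change-of-variables correction: stan samples gamma, but the Dirichlet density is stated for lambdas. A short derivation, for a parameter set with \(K\) = n_param_each[i] elements, free parameters \(\gamma_1, \dots, \gamma_{K-1}\), \(S\) = sum_gammas[i], and taking \(\lambda_2, \dots, \lambda_K\) as the free coordinates of the simplex (the determinant step uses the matrix determinant lemma):

```latex
\begin{aligned}
S &= 1 + \sum_{j=1}^{K-1} \gamma_j, \qquad \lambda_1 = \frac{1}{S}, \qquad \lambda_{j+1} = \frac{\gamma_j}{S} \quad (j = 1, \dots, K-1) \\
\frac{\partial \lambda_{j+1}}{\partial \gamma_k} &= \frac{\delta_{jk}}{S} - \frac{\gamma_j}{S^2}
  \quad\Longrightarrow\quad J = \frac{1}{S}\left(I - \frac{\gamma \mathbf{1}^\top}{S}\right) \\
\det J &= S^{-(K-1)} \left(1 - \frac{\mathbf{1}^\top \gamma}{S}\right) = S^{-(K-1)} \cdot \frac{1}{S} = S^{-K} \\
\log p(\gamma) &= \log \mathrm{Dirichlet}\left(\lambda(\gamma) \mid \alpha\right) - K \log S
\end{aligned}
```

With \(\alpha\) the slice of lambdas_prior for the set, the last line is exactly the two target += statements inside the first loop of the model block. A set with a single element has \(S = 1\), so the correction vanishes, consistent with the special case in the transformed parameters block.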

The stan model works as follows (technical!):

  • We are interested in “sets” of parameters. For example, in the \(X \rightarrow Y\) model we have two parameter sets (param_sets). The first is \(\lambda^X = \{\lambda^X_0, \lambda^X_1\}\), whose elements give the probability that \(X\) is 0 or 1. These two probabilities sum to one. The second parameter set is \(\lambda^Y = \{\lambda^Y_{00}, \lambda^Y_{10}, \lambda^Y_{01}, \lambda^Y_{11}\}\). These are also probabilities and their values sum to one. Note that in all we have 6 parameters but just 1 + 3 = 4 degrees of freedom.

  • We would like to express priors over these parameters using multiple Dirichlet distributions (two in this case). Because we are dealing with multiple simplices of varying length, it is easier to sample unconstrained positive quantities (the gamma parameters) and normalize them within each parameter set. The motivating fact is that if \(\gamma^X_0 \sim Gamma(\alpha^X_0,1)\) and \(\gamma^X_1 \sim Gamma(\alpha^X_1,1)\) independently, then \(\frac{1}{\gamma^X_0 +\gamma^X_1}(\gamma^X_0, \gamma^X_1) \sim Dirichlet(\alpha^X_0, \alpha^X_1)\). Earlier versions placed gamma priors on these quantities directly. The current code fixes the first element of each set at 1, sets \(\lambda\) equal to \((1, \gamma)\) divided by sum_gammas, and places the Dirichlet prior on \(\lambda\) itself, adding the log-Jacobian term -n_param_each[i] * log(sum_gammas[i]) to account for the change of variables. For a discussion of implementation of this approach in stan see https://discourse.mc-stan.org/t/ragged-array-of-simplexes/1382.

  • For any candidate parameter vector \(\lambda\) we calculate the probability of causal types (prob_of_types) by taking, for each type \(i\), the product of the probabilities of all parameters (\(\lambda_j\)) that appear in column \(i\) of the parameter matrix \(P\). Thus the probability of a \((X_0,Y_{00})\) case is just \(\lambda^X_0 \times \lambda^Y_{00}\). The implementation in stan uses prob_of_types[i] \(= \prod_j \left(P_{j,i} \lambda_j + (1-P_{j,i})\right)\): this multiplies the probabilities of all parameters involved in the causal type and substitutes 1s for parameters that are not. Only \(P\) is passed as data; \(1-P\) is computed inside stan, and prob_of_types is only calculated when keep_transformed is 1.

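  • The probability of data types, w, is given by summing the probabilities of all causal types that produce a given data type. For example, the probability of an \(X=0,Y=0\) case, \(w_{00}\), is \(\lambda^X_0\times \lambda^Y_{00} + \lambda^X_0\times \lambda^Y_{01}\). The current code computes this sum without enumerating causal types: parmap flags, for each data type, the parameters consistent with it; these are summed within each node (parlam2) and the sums are multiplied across nodes (w_0), giving \(w_{00} = \lambda^X_0 (\lambda^Y_{00} + \lambda^Y_{01})\). The matrix map then aggregates these path probabilities into data types (in this example it is the identity).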

  • In the case of incomplete data we first identify the set of “data strategies”. A collection of data strategies might take the form “gather data on \(X\) and \(M\), but not \(Y\), for \(n_1\) cases, and gather data on \(X\) and \(Y\), but not \(M\), for \(n_2\) cases.” The probability of an observed event, within a data strategy, is given by summing the probabilities of the data types that could give rise to the incomplete data. For example, if \(X\) is observed but \(Y\) is not, then the probability of \(X=0, Y = \text{NA}\) is \(w_{00} +w_{01}\). The matrix \(E\) is passed to stan to determine which event probabilities need to be combined for events with missing data.

  • The probability of a dataset is then given by a multinomial distribution with these event probabilities (or, in the case of incomplete data, the product of multinomials, one for each data strategy). Justification for this approach relies on the likelihood principle and is discussed in Chapter 6.
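The steps in the list above can be traced by hand for the \(X \rightarrow Y\) model. The following is a base-R sketch, not the package's code: it hard-codes P, parmap and E as printed by prep_stan_data above, uses an arbitrary illustrative \(\lambda\), and writes out an ambiguity matrix A by hand (hypothetical, since A is not among the prep_stan_data outputs) only to check that both routes give the same w.

```r
pars   <- c("X.0", "X.1", "Y.00", "Y.10", "Y.01", "Y.11")
types  <- c("X0.Y00", "X1.Y00", "X0.Y10", "X1.Y10",
            "X0.Y01", "X1.Y01", "X0.Y11", "X1.Y11")
dtypes <- c("X0Y0", "X1Y0", "X0Y1", "X1Y1")

# Matrices copied from the prep_stan_data output above
P <- matrix(c(1, 0, 1, 0, 1, 0, 1, 0,
              0, 1, 0, 1, 0, 1, 0, 1,
              1, 1, 0, 0, 0, 0, 0, 0,
              0, 0, 1, 1, 0, 0, 0, 0,
              0, 0, 0, 0, 1, 1, 0, 0,
              0, 0, 0, 0, 0, 0, 1, 1),
            nrow = 6, byrow = TRUE, dimnames = list(pars, types))
parmap <- matrix(c(1, 0, 1, 0,
                   0, 1, 0, 1,
                   1, 1, 0, 0,
                   0, 1, 1, 0,
                   1, 0, 0, 1,
                   0, 0, 1, 1),
                 nrow = 6, byrow = TRUE, dimnames = list(pars, dtypes))
E <- rbind(diag(4), c(1, 1, 0, 0), c(0, 0, 1, 1))
dimnames(E) <- list(c(dtypes, "Y0", "Y1"), dtypes)
Y <- c(1, 1, 0, 1, 0, 1)                     # event counts from compact_data
strategy_starts <- c(1, 5); strategy_ends <- c(4, 6)
node_starts <- c(1, 3); node_ends <- c(2, 6)

# Illustrative parameter values (each parameter set sums to 1)
lambdas <- c(0.4, 0.6, 0.1, 0.2, 0.3, 0.4)

# 1. Causal type probabilities: prod_j (P[j, i] * lambda_j + 1 - P[j, i])
prob_of_types <- apply(P * lambdas + 1 - P, 2, prod)

# 2. Data type probabilities via parmap: sum within nodes, multiply across nodes
parlam  <- parmap * lambdas                  # scales row j by lambdas[j]
parlam2 <- t(sapply(seq_along(node_starts), function(i)
  colSums(parlam[node_starts[i]:node_ends[i], , drop = FALSE])))
w <- apply(parlam2, 2, prod)                 # map is the identity here

# Check against summing causal type probabilities by data type
A <- matrix(0, 8, 4, dimnames = list(types, dtypes))
A[cbind(types, c("X0Y0", "X1Y0", "X0Y1", "X1Y0",
                 "X0Y0", "X1Y1", "X0Y1", "X1Y1"))] <- 1
w_from_types <- as.vector(t(A) %*% prob_of_types)

# 3. Event probabilities, including the Y-only strategy
w_full <- as.vector(E %*% w)

# 4. Log likelihood: one multinomial per data strategy, normalized within strategy
loglik <- sum(sapply(seq_along(strategy_starts), function(i) {
  idx <- strategy_starts[i]:strategy_ends[i]
  dmultinom(Y[idx], prob = w_full[idx] / sum(w_full[idx]), log = TRUE)
}))
```

In the stan summary further below, the same quantities appear draw by draw as prob_of_types, parlam2, w and w_full.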

4.3 Implementation

To update a CausalQueries model with data use:

update_model(model, data)

where the data argument is a dataset containing some or all of the nodes in the model.

Other stan arguments can be passed to update_model, in particular:

  • iter sets the number of iterations and ultimately the number of draws in the posterior
  • chains sets the number of chains; doing multiple chains in parallel speeds things up
  • many other options; see ?rstan::stan

If you have multiple cores you can do parallel processing by including this line before running CausalQueries:

options(mc.cores = parallel::detectCores())

The stan output from a simple model looks like this:

Inference for Stan model: simplexes.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

                   mean se_mean     sd   2.5%    25%
gamma[1]           3.38    0.24   9.81   0.27   0.89
gamma[2]           5.55    2.11  94.64   0.02   0.24
gamma[3]           8.13    2.85  81.99   0.04   0.48
gamma[4]           5.62    2.25  71.60   0.03   0.35
lambdas[1]         0.39    0.00   0.20   0.06   0.23
lambdas[2]         0.61    0.00   0.20   0.21   0.47
lambdas[3]         0.26    0.00   0.17   0.01   0.12
lambdas[4]         0.20    0.00   0.15   0.01   0.07
lambdas[5]         0.31    0.00   0.19   0.02   0.15
lambdas[6]         0.23    0.00   0.17   0.01   0.09
sum_gammas[1]      4.38    0.24   9.81   1.27   1.89
sum_gammas[2]     20.30    7.18 226.68   1.56   2.63
parlam[1,1]        0.39    0.00   0.20   0.06   0.23
parlam[1,2]        0.00     NaN   0.00   0.00   0.00
parlam[1,3]        0.39    0.00   0.20   0.06   0.23
parlam[1,4]        0.00     NaN   0.00   0.00   0.00
parlam[2,1]        0.00     NaN   0.00   0.00   0.00
parlam[2,2]        0.61    0.00   0.20   0.21   0.47
parlam[2,3]        0.00     NaN   0.00   0.00   0.00
parlam[2,4]        0.61    0.00   0.20   0.21   0.47
parlam[3,1]        0.26    0.00   0.17   0.01   0.12
parlam[3,2]        0.26    0.00   0.17   0.01   0.12
parlam[3,3]        0.00     NaN   0.00   0.00   0.00
parlam[3,4]        0.00     NaN   0.00   0.00   0.00
parlam[4,1]        0.00     NaN   0.00   0.00   0.00
parlam[4,2]        0.20    0.00   0.15   0.01   0.07
parlam[4,3]        0.20    0.00   0.15   0.01   0.07
parlam[4,4]        0.00     NaN   0.00   0.00   0.00
parlam[5,1]        0.31    0.00   0.19   0.02   0.15
parlam[5,2]        0.00     NaN   0.00   0.00   0.00
parlam[5,3]        0.00     NaN   0.00   0.00   0.00
parlam[5,4]        0.31    0.00   0.19   0.02   0.15
parlam[6,1]        0.00     NaN   0.00   0.00   0.00
parlam[6,2]        0.00     NaN   0.00   0.00   0.00
parlam[6,3]        0.23    0.00   0.17   0.01   0.09
parlam[6,4]        0.23    0.00   0.17   0.01   0.09
parlam2[1,1]       0.39    0.00   0.20   0.06   0.23
parlam2[1,2]       0.61    0.00   0.20   0.21   0.47
parlam2[1,3]       0.39    0.00   0.20   0.06   0.23
parlam2[1,4]       0.61    0.00   0.20   0.21   0.47
parlam2[2,1]       0.57    0.00   0.20   0.17   0.42
parlam2[2,2]       0.46    0.00   0.18   0.12   0.32
parlam2[2,3]       0.43    0.00   0.20   0.07   0.28
parlam2[2,4]       0.54    0.00   0.18   0.19   0.41
w_0[1]             0.22    0.00   0.14   0.03   0.11
w_0[2]             0.28    0.00   0.14   0.05   0.17
w_0[3]             0.17    0.00   0.13   0.01   0.07
w_0[4]             0.34    0.00   0.16   0.07   0.22
w[1]               0.22    0.00   0.14   0.03   0.11
w[2]               0.28    0.00   0.14   0.05   0.17
w[3]               0.17    0.00   0.13   0.01   0.07
w[4]               0.34    0.00   0.16   0.07   0.22
w_full[1]          0.22    0.00   0.14   0.03   0.11
w_full[2]          0.28    0.00   0.14   0.05   0.17
w_full[3]          0.17    0.00   0.13   0.01   0.07
w_full[4]          0.34    0.00   0.16   0.07   0.22
w_full[5]          0.49    0.00   0.15   0.21   0.40
w_full[6]          0.51    0.00   0.15   0.23   0.40
prob_of_types[1]   0.10    0.00   0.09   0.00   0.03
prob_of_types[2]   0.16    0.00   0.12   0.01   0.06
prob_of_types[3]   0.08    0.00   0.08   0.00   0.02
prob_of_types[4]   0.12    0.00   0.10   0.00   0.04
prob_of_types[5]   0.11    0.00   0.10   0.00   0.04
prob_of_types[6]   0.19    0.00   0.14   0.01   0.08
prob_of_types[7]   0.09    0.00   0.09   0.00   0.03
prob_of_types[8]   0.14    0.00   0.12   0.00   0.05
lp__             -10.41    0.05   1.70 -14.74 -11.23
                    50%   75% 97.5% n_eff Rhat
gamma[1]           1.73  3.35 15.40  1716    1
gamma[2]           0.70  1.98 24.44  2021    1
gamma[3]           1.19  2.99 36.14   827    1
gamma[4]           0.88  2.16 20.19  1015    1
lambdas[1]         0.37  0.53  0.79  2420    1
lambdas[2]         0.63  0.77  0.94  2420    1
lambdas[3]         0.24  0.38  0.64  1783    1
lambdas[4]         0.16  0.29  0.57  3125    1
lambdas[5]         0.29  0.44  0.72  3146    1
lambdas[6]         0.20  0.34  0.63  3925    1
sum_gammas[1]      2.73  4.35 16.40  1716    1
sum_gammas[2]      4.22  8.16 82.29   996    1
parlam[1,1]        0.37  0.53  0.79  2420    1
parlam[1,2]        0.00  0.00  0.00   NaN  NaN
parlam[1,3]        0.37  0.53  0.79  2420    1
parlam[1,4]        0.00  0.00  0.00   NaN  NaN
parlam[2,1]        0.00  0.00  0.00   NaN  NaN
parlam[2,2]        0.63  0.77  0.94  2420    1
parlam[2,3]        0.00  0.00  0.00   NaN  NaN
parlam[2,4]        0.63  0.77  0.94  2420    1
parlam[3,1]        0.24  0.38  0.64  1783    1
parlam[3,2]        0.24  0.38  0.64  1783    1
parlam[3,3]        0.00  0.00  0.00   NaN  NaN
parlam[3,4]        0.00  0.00  0.00   NaN  NaN
parlam[4,1]        0.00  0.00  0.00   NaN  NaN
parlam[4,2]        0.16  0.29  0.57  3125    1
parlam[4,3]        0.16  0.29  0.57  3125    1
parlam[4,4]        0.00  0.00  0.00   NaN  NaN
parlam[5,1]        0.29  0.44  0.72  3146    1
parlam[5,2]        0.00  0.00  0.00   NaN  NaN
parlam[5,3]        0.00  0.00  0.00   NaN  NaN
parlam[5,4]        0.29  0.44  0.72  3146    1
parlam[6,1]        0.00  0.00  0.00   NaN  NaN
parlam[6,2]        0.00  0.00  0.00   NaN  NaN
parlam[6,3]        0.20  0.34  0.63  3925    1
parlam[6,4]        0.20  0.34  0.63  3925    1
parlam2[1,1]       0.37  0.53  0.79  2420    1
parlam2[1,2]       0.63  0.77  0.94  2420    1
parlam2[1,3]       0.37  0.53  0.79  2420    1
parlam2[1,4]       0.63  0.77  0.94  2420    1
parlam2[2,1]       0.57  0.72  0.93  3091    1
parlam2[2,2]       0.45  0.59  0.81  3448    1
parlam2[2,3]       0.43  0.58  0.83  3091    1
parlam2[2,4]       0.55  0.68  0.88  3448    1
w_0[1]             0.19  0.30  0.54  2606    1
w_0[2]             0.26  0.37  0.60  3404    1
w_0[3]             0.14  0.24  0.48  2795    1
w_0[4]             0.32  0.45  0.69  2454    1
w[1]               0.19  0.30  0.54  2606    1
w[2]               0.26  0.37  0.60  3404    1
w[3]               0.14  0.24  0.48  2795    1
w[4]               0.32  0.45  0.69  2454    1
w_full[1]          0.19  0.30  0.54  2606    1
w_full[2]          0.26  0.37  0.60  3404    1
w_full[3]          0.14  0.24  0.48  2795    1
w_full[4]          0.32  0.45  0.69  2454    1
w_full[5]          0.50  0.60  0.77  2542    1
w_full[6]          0.50  0.60  0.79  2542    1
prob_of_types[1]   0.08  0.14  0.35  2015    1
prob_of_types[2]   0.13  0.23  0.46  1866    1
prob_of_types[3]   0.05  0.11  0.31  2543    1
prob_of_types[4]   0.09  0.17  0.37  3540    1
prob_of_types[5]   0.09  0.16  0.36  3012    1
prob_of_types[6]   0.16  0.28  0.53  2627    1
prob_of_types[7]   0.06  0.13  0.32  3336    1
prob_of_types[8]   0.11  0.21  0.45  3582    1
lp__             -10.00 -9.16 -8.32  1187    1

Samples were drawn using NUTS(diag_e) at Thu Oct 12 16:06:37 2023.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

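Note that the output includes the gamma parameters plus the transformed parameters: \(\lambda\) (lambdas), which are our parameters of interest and which CausalQueries then interprets as possible row probabilities for the \(P\) matrix, along with intermediate quantities (sum_gammas, parlam, parlam2, w_0, w, w_full). The generated quantities prob_of_types give the implied causal type probabilities.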
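The link between gamma and lambdas follows the transformed parameters block of the stan code: within each parameter set, \((1, \gamma)\) is divided by sum_gammas. The base-R function below is a hypothetical helper, not part of CausalQueries, that mirrors that block, including the index arithmetic used to pick each set's gammas out of the single gamma vector. Within each draw, for instance, lambdas[1] equals 1/sum_gammas[1]; this does not hold for the posterior means in the summary above, since the transformation is nonlinear.

```r
# Hypothetical helper mirroring the stan transformed parameters block:
# gamma (length n_params - n_param_sets) -> lambdas (length n_params)
gamma_to_lambdas <- function(gamma, l_starts, l_ends) {
  lambdas    <- numeric(max(l_ends))
  sum_gammas <- numeric(length(l_starts))
  for (i in seq_along(l_starts)) {
    if (l_starts[i] >= l_ends[i]) {
      # Single-valued parameter set: no gamma component, probability fixed at 1
      sum_gammas[i] <- 1
      lambdas[l_starts[i]] <- 1
    } else {
      # Each earlier set used (size - 1) gammas, hence the offsets (i - 1) and i
      g <- gamma[(l_starts[i] - (i - 1)):(l_ends[i] - i)]
      sum_gammas[i] <- 1 + sum(g)
      lambdas[l_starts[i]:l_ends[i]] <- c(1, g) / sum_gammas[i]
    }
  }
  list(lambdas = lambdas, sum_gammas = sum_gammas)
}

# X -> Y: set X holds parameters 1-2, set Y parameters 3-6; 4 gammas in total
out <- gamma_to_lambdas(gamma = c(3, 1, 2, 1), l_starts = c(1, 3), l_ends = c(2, 6))
out$sum_gammas  # 4 5
out$lambdas     # 0.25 0.75 0.20 0.20 0.40 0.20
```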

4.4 Extensions

4.4.1 Arbitrary parameters

Although the package provides helpers to generate mappings from parameters to causal types via nodal types, it is possible to dispense with the nodal types altogether and provide a direct mapping from parameters to causal types.

For this you need to manually provide a P matrix and a corresponding parameters_df. As an example here is a model with complete confounding and parameters that correspond to causal types directly.

model <- make_model("X->Y")

model$P <- diag(8)
colnames(model$P) <- rownames(model$causal_types)

model$parameters_df <- data.frame(
  param_names = paste0("x",1:8), 
  param_set = 1, 
  priors = 1, 
  parameters = 1/8)

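# Update fully confounded model on strongly correlated data
# (data are simulated from the standard, unconfounded X->Y model;
# 'model' itself must not be re-created here, or the custom P is lost)
data <- make_data(make_model("X->Y"), n = 100, parameters = c(.5, .5, .1, .1, .7, .1))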

fully_confounded <- update_model(model, data)
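With a single parameter set over all eight causal types and \(P\) equal to the identity, the prob_of_types formula from the stan code reduces to the parameters themselves: each parameter is directly a causal type probability, with 8 − 1 = 7 degrees of freedom against 1 + 3 = 4 in the unconfounded model. A base-R check with illustrative values:

```r
# Illustrative causal type probabilities (one parameter set, sums to 1)
lambdas <- c(0.05, 0.10, 0.15, 0.20, 0.05, 0.10, 0.15, 0.20)
P <- diag(8)

# Same formula as the generated quantities block: prod_j (P[j,i] * lambda_j + 1 - P[j,i])
prob_of_types <- apply(P * lambdas + 1 - P, 2, prod)
```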

4.4.2 Non binary data

In principle the stan model could be extended to handle non-binary data. Although the current package is limited to binary nodes, there is no structural reason why nodes should be constrained to be dichotomous. The set of nodal and causal types, however, expands even more rapidly in the case of non-binary data.
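One way to see how fast the type space grows: a nodal type assigns a value of the node to every combination of its parents' values, so a node with \(k\) possible values whose parents have cardinalities \(k_1, \dots, k_m\) has \(k^{k_1 \cdots k_m}\) nodal types. Without confounding, the number of causal types is the product of these counts across nodes. The helper below is hypothetical and only counts types; it is not part of CausalQueries.

```r
# Number of nodal types for a node with k values whose parents have
# cardinalities parent_k: one output value per combination of parent values
n_nodal_types <- function(k, parent_k = integer(0)) {
  k^prod(parent_k)   # prod(integer(0)) is 1, so a root node has k types
}

n_nodal_types(2)                          # binary root node: 2
n_nodal_types(2, 2)                       # binary Y with binary X: 4
n_nodal_types(2) * n_nodal_types(2, 2)    # causal types for binary X -> Y: 8
n_nodal_types(2, c(2, 2))                 # binary Y with two binary parents: 16
n_nodal_types(3, 3)                       # ternary Y with ternary X: 27
n_nodal_types(3) * n_nodal_types(3, 3)    # causal types for ternary X -> Y: 81
n_nodal_types(3, c(3, 3))                 # ternary Y with two ternary parents: 19683
```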