Channel: Active questions tagged python - Stack Overflow

Problem with Matrix Multiplication in cmdstanpy


I am trying to run a Bayesian linear regression of a target variable (pred_Factor) on an independent variable (expl_factor). I want predictions for each combination of the categories in cat_1 (2 categories, coded 0 and 1) and cat_2 (4 categories, coded 0 to 3).
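One way to read "predictions for each category combination" (an assumption about the intent) is a separate slope of expl_factor for each of the 2 x 4 = 8 cells of cat_1 x cat_2. If cat_1 and cat_2 are treated as categorical, a formula that keeps only the expl_factor terms and their interactions should span this same set of slopes, just parameterized differently. A minimal pandas sketch of such a design matrix, on made-up data; the toy df and the cell_* column names are hypothetical:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data standing in for the question's df (not the real data)
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "expl_factor": rng.normal(size=16),
    "cat_1": np.tile([0, 1], 8),          # 2 categories, coded 0 and 1
    "cat_2": np.repeat([0, 1, 2, 3], 4),  # 4 categories, coded 0 to 3
})

# Label each row with its (cat_1, cat_2) combination: 2 x 4 = 8 possible cells
cell = df["cat_1"].astype(str) + "_" + df["cat_2"].astype(str)

# One 0/1 indicator column per cell
indicators = pd.get_dummies(cell, prefix="cell", dtype=float)

# Multiply each indicator by expl_factor: column k holds expl_factor only for
# rows in cell k, so its coefficient acts as the slope for that combination
X_cells = indicators.mul(df["expl_factor"], axis=0)

print(X_cells.shape)  # (16, 8)
print(list(X_cells.columns))
```

Passing a matrix like this to Stan would need matrix[N, 8] and vector[8] instead of 3. By contrast, the raw columns expl_factor, cat_1 and cat_2 give cat_2 a single numeric slope across its four codes.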

I have used this code:

    import cmdstanpy

    model_formula = "pred_Factor ~ 0 + expl_factor * cat_1 * cat_2 - cat_2 - cat_1 - cat_2:cat_1"

    code = f"""
    data {{
        int<lower=0> N;              // Number of data points
        vector[N] y;                 // Response variable
        matrix[N, 3] X;              // Predictor matrix
    }}

    parameters {{
        real<lower=0> beta;          // Population estimate
        vector[3] b;                 // Slope coefficients for interactions
        real<lower=0> sigma;         // Error scale
    }}

    model {{
        beta ~ normal(1, 5);                        // Prior for population estimate
        b ~ normal(1, 5);                           // Prior for slope coefficients
        sigma ~ student_t(3, 0, 5);                 // Prior for error scale
        y ~ normal(X * (beta + X * b), sigma);      // Likelihood with interaction terms
    }}
    """

    stan_file = "model.stan"
    with open(stan_file, "w") as file:
        file.write(code)

    model = cmdstanpy.CmdStanModel(stan_file=stan_file)

    data = {
        "N": len(df),
        "y": df["pred_Factor"].values,
        "X": df[["expl_factor", "cat_1", "cat_2"]].values,
    }

    fit = model.sample(
        data=data,
        seed=20240122,
        chains=2,
        iter_warmup=1000,
        iter_sampling=4000,
        max_treedepth=12,
    )
    print(fit.summary())

But I keep getting this error message:

    RuntimeError: Error during sampling:
    Exception: multiply: Columns of m1 (3) and Rows of m2 (121) must match in size (in 'model.stan', line 18, column 4 to column 42)
    Exception: multiply: Columns of m1 (3) and Rows of m2 (121) must match in size (in 'model.stan', line 18, column 4 to column 42)
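In Stan, * between a matrix and a vector is matrix multiplication. X * b is (N x 3) times a 3-vector, giving vector[N], so beta + X * b already has one value per observation. Multiplying X by that N-vector again asks 3 columns to match 121 rows, which matches the message (m1 = X, m2 = the 121-element vector). The same shapes in NumPy, with random stand-in values rather than the question's data, reproduce the failure:

```python
import numpy as np

N = 121                       # number of rows, as reported in the question
rng = np.random.default_rng(0)
X = rng.normal(size=(N, 3))   # stand-in for matrix[N, 3] X
b = rng.normal(size=3)        # stand-in for vector[3] b
beta = 1.0                    # stand-in for real beta

inner = beta + X @ b          # (121, 3) @ (3,) -> shape (121,)
print(inner.shape)            # (121,)

# The second multiplication X * inner needs (121, 3) @ (121,): 3 columns vs 121 rows
try:
    X @ inner
    mismatch = False
except ValueError as err:
    mismatch = True
    print("ValueError:", err)
```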

I have checked the shapes of X and y, and they are as expected:

    Shape of X: (121, 3)
    Shape of y: (121,)

I have also tried changing y ~ normal(X * (beta + X * b), sigma) to y ~ normal(X * (beta + X[:,1:3] * b), sigma), but the error persists.
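The X[:,1:3] variant cannot help, because Stan indexing is 1-based with inclusive ranges, so X[:, 1:3] is all three columns and the outer X * (...) is unchanged. A sketch of a likely fix, assuming the intended model is "intercept plus one coefficient per column of X", drops the outer multiplication. The program below keeps the question's declarations and priors and changes only the likelihood line; the NumPy lines just check the indexing and the resulting shape:

```python
import numpy as np

# Sketch of a corrected Stan program (assumed intent): the linear predictor
# beta + X * b is already a vector[N], so it is used directly as the mean.
stan_code = """
data {
    int<lower=0> N;
    vector[N] y;
    matrix[N, 3] X;
}
parameters {
    real<lower=0> beta;
    vector[3] b;
    real<lower=0> sigma;
}
model {
    beta ~ normal(1, 5);
    b ~ normal(1, 5);
    sigma ~ student_t(3, 0, 5);
    y ~ normal(beta + X * b, sigma);
}
"""

# Stan's X[:, 1:3] (1-based, inclusive) corresponds to NumPy's X[:, 0:3],
# which is the whole 3-column matrix
X = np.arange(121 * 3, dtype=float).reshape(121, 3)
same_matrix = np.array_equal(X[:, 0:3], X)
print(same_matrix)   # True

# Shape of the corrected linear predictor, using stand-in values for beta and b
b = np.ones(3)
mu = 1.0 + X @ b
print(mu.shape)      # (121,): one mean per observation
```

Writing stan_code to model.stan and compiling it with cmdstanpy.CmdStanModel, as in the question, should then get past the multiply error. Stan also provides normal_id_glm (y ~ normal_id_glm(X, beta, b, sigma);) for this linear-regression likelihood; it is worth confirming that the installed CmdStan version supports it.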


