psborrow2
psborrow2 copied to clipboard
piecewise exponential
We can use this thread to discuss the PEM outcome distribution.
What would be helpful is:
- function name
- arguments
- default behavior
- common errors to anticipate
After this, we can start breaking it up and writing the code.
Good function to look at would be exp_surv_dist()
And a good first step would be cloning the repo and getting devtools::test() to run.
I think this is what I got working from Manoj's original script:
#This R script fits a commensurate prior model using rstan.
# Date: 08/04/2023
library(psborrow2)
library(WeightIt)
library(R2jags)
# Convert the example matrix to a data frame and derive the internal-trial
# indicator as the complement of the external flag (int = 1 - ext).
example_dataframe <- as.data.frame(example_matrix)
example_dataframe[["int"]] <- 1 - example_dataframe[["ext"]]

# Model membership in the internal trial from the four covariates and store
# the resulting weights (ATT estimand, "gbm" method) in the `att` column.
weightit_model <- weightit(
  int ~ cov1 + cov2 + cov3 + cov4,
  data = example_dataframe,
  method = "gbm",
  estimand = "ATT"
)
example_dataframe[["att"]] <- weightit_model[["weights"]]
# Split the weighted data by arm: internal trial rows vs. external controls
trial.data <- example_dataframe[example_dataframe[["int"]] == 1, ]
ext.control.data <- example_dataframe[example_dataframe[["int"]] == 0, ]

# Inputs for the Stan program -- internal trial
nT <- nrow(trial.data)              # sample size
time <- trial.data[["time"]]        # survival time
status <- trial.data[["status"]]    # event indicator
Z <- trial.data[["trt"]]            # treatment indicator
wt <- trial.data[["att"]]           # weights
X <- trial.data[, paste0("cov", 1:4)] # covariates cov1..cov4

# Inputs for the Stan program -- external controls
nE <- nrow(ext.control.data)             # sample size
timeE <- ext.control.data[["time"]]      # survival time
statusE <- ext.control.data[["status"]]  # event indicator
wtE <- ext.control.data[["att"]]         # weights
XE <- ext.control.data[, paste0("cov", 1:4)] # covariates cov1..cov4

# Number of covariate columns
Nbetas <- ncol(X)

# Zero vectors (one entry per subject) for the zeros trick
zeros <- rep(0, nT)
zerosE <- rep(0, nE)
# Partition of the time axis
K <- 5 # number of intervals

# Cut points: 0, then the interior K-quantiles of the observed event times in
# the trial data, then a value slightly above the largest time in either
# data set so that every observation lies below the last cut point.
a <- c(
  0,
  quantile(
    trial.data$time[trial.data$status == 1],
    probs = seq(0, 1, by = 1 / K)
  )[-c(1, K + 1)],
  max(trial.data$time, ext.control.data$time) + 0.0001
)
# Trial data
# int.obs: index k (1..length(a) - 1) of the interval (a[k], a[k + 1]] that
# contains each trial subject's time.
#
# The previous element-wise loop required a[k] < time < a[k + 1] strictly, so
# a time exactly equal to a cut point (possible, because the interior cut
# points are quantiles of observed event times) matched no interval and was
# given index 0. findInterval() with left-open/right-closed intervals assigns
# such a time to the interval it closes, which agrees with how the Stan
# program accumulates the cumulative hazard via step(time - a[k + 1]).
# rightmost.closed = TRUE (which, with left.open = TRUE, closes the leftmost
# interval) places time == a[1] in interval 1 rather than 0.
int.obs <- findInterval(
  trial.data$time,
  a,
  left.open = TRUE,
  rightmost.closed = TRUE
)
# d: subject-by-interval 0/1 indicator matrix, kept because the original
# script defined it as a top-level object.
d <- outer(int.obs, seq_len(length(a) - 1), "==") * 1
# External control data
# int.obsE: index k (1..length(a) - 1) of the interval (a[k], a[k + 1]] that
# contains each external control subject's time.
#
# The previous element-wise loop required a[k] < time < a[k + 1] strictly, so
# a time exactly equal to a cut point matched no interval and was given
# index 0. findInterval() with left-open/right-closed intervals assigns such a
# time to the interval it closes, which agrees with how the Stan program
# accumulates the cumulative hazard via step(timeE - a[k + 1]).
# rightmost.closed = TRUE (which, with left.open = TRUE, closes the leftmost
# interval) places time == a[1] in interval 1 rather than 0.
int.obsE <- findInterval(
  ext.control.data$time,
  a,
  left.open = TRUE,
  rightmost.closed = TRUE
)
# d: subject-by-interval 0/1 indicator matrix, kept because the original
# script defined it as a top-level object (it is overwritten here with the
# external control version, as before).
d <- outer(int.obsE, seq_len(length(a) - 1), "==") * 1
# Using rstan
library(rstan)

# Named list of inputs handed to stan(); element names must match the names
# used in the Stan program (hence int_obs / int_obsE rather than the dotted
# R names).
d.jags <- list(
  nT = nT,
  nE = nE,
  time = time,
  timeE = timeE,
  a = a,
  X = X,
  XE = XE,
  int_obs = int.obs,
  int_obsE = int.obsE,
  Nbetas = Nbetas,
  zeros = zeros,
  zerosE = zerosE,
  wt = wt,
  status = status,
  statusE = statusE,
  wtE = wtE,
  Z = Z,
  K = K
)
# Stan program: weighted piecewise exponential proportional hazards model with
# separate log baseline hazards (alpha / alphaE) and covariate coefficients
# (beta / beta0) for the trial and the external controls. Each subject adds
# weight * (status * log-hazard + log-survival) to the log density.
teststan <- "
data {
  int<lower=1> nT; // Number of trial data points
  int<lower=1> nE; // Number of external control data points
  int<lower=1> K; // Number of time intervals
  int<lower=1,upper=K> int_obs[nT]; // Interval index containing each trial time
  int<lower=1,upper=K> int_obsE[nE]; // Interval index containing each external control time
  real time[nT]; // Time points for trial data
  real timeE[nE]; // Time points for external control data
  vector[K + 1] a; // Interval cut points (shared by both data sets)
  int<lower=1> Nbetas; // Number of covariates
  matrix[nT, Nbetas] X; // Covariate matrix for trial data
  matrix[nE, Nbetas] XE; // Covariate matrix for external control data
  vector[nT] wt; // Weights for trial data
  vector[nE] wtE; // Weights for external control data
  vector[nT] status; // Status for trial data (1 for event, 0 for censored)
  vector[nE] statusE; // Status for external control data (1 for event, 0 for censored)
  vector[nT] Z; // Treatment indicator for trial data
}
parameters {
  vector[K] alpha; // Log baseline hazard per interval, trial data
  vector[K] alphaE; // Log baseline hazard per interval, external control data
  vector[Nbetas] beta0; // Covariate coefficients for external control data
  // NOTE(review): tau is declared but never used in the model block, and no
  // prior links alpha to alphaE or beta to beta0, so the two data sets are
  // fitted independently and tau has no proper posterior. Confirm the
  // intended commensurate priors and hyperpriors before relying on this fit.
  vector<lower=0>[Nbetas] tau; // Precision for covariates
  vector[Nbetas] beta; // Covariate coefficients for trial data
  real gamma; // Treatment effect (log hazard ratio)
}
transformed parameters {
  real hr;
  hr = exp(gamma); // Calculate hazard ratio
}
model {
  matrix[nT, K] cond; // Step function indicators for trial data
  matrix[nT, K] HH; // Per-interval cumulative baseline hazard for trial data
  vector[nT] H; // Cumulative baseline hazard for trial data
  matrix[nE, K] condE; // Step function indicators for external control data
  matrix[nE, K] HHE; // Per-interval cumulative baseline hazard for external control data
  vector[nE] HE; // Cumulative baseline hazard for external control data
  vector[nT] elinpred;
  vector[nE] elinpredE;
  vector[nT] logHaz;
  vector[nE] logHazE;
  vector[nT] logSurv;
  vector[nE] logSurvE;
  // Trial Data
  for (i in 1:nT) {
    for (k in 1:int_obs[i]) {
      // cond = 1: interval k fully survived (full width); cond = 0: partial width
      cond[i, k] = step(time[i] - a[k + 1]);
      HH[i, k] = cond[i, k] * (a[k + 1] - a[k]) * exp(alpha[k]) +
                 (1 - cond[i, k]) * (time[i] - a[k]) * exp(alpha[k]);
    }
    H[i] = sum(HH[i, 1:int_obs[i]]);
    elinpred[i] = exp(X[i,] * beta + gamma*Z[i]);
    logHaz[i] = log(exp(alpha[int_obs[i]]) * elinpred[i]);
    logSurv[i] = -H[i] * elinpred[i];
    target += wt[i] * status[i] * logHaz[i] + wt[i] * logSurv[i];
  }
  // External Control Data
  for (i in 1:nE) {
    for (k in 1:int_obsE[i]) {
      condE[i, k] = step(timeE[i] - a[k + 1]);
      HHE[i, k] = condE[i, k] * (a[k + 1] - a[k]) * exp(alphaE[k]) +
                  (1 - condE[i, k]) * (timeE[i] - a[k]) * exp(alphaE[k]);
    }
    HE[i] = sum(HHE[i, 1:int_obsE[i]]);
    elinpredE[i] = exp(XE[i,] * beta0);
    // The weight is applied once, in the target increment below. It was
    // previously also multiplied inside this log, unlike the trial arm; that
    // only added a parameter-free constant and broke down for a zero weight.
    logHazE[i] = log(exp(alphaE[int_obsE[i]]) * elinpredE[i]);
    logSurvE[i] = -HE[i] * elinpredE[i];
    target += wtE[i] * statusE[i] * logHazE[i] + wtE[i] * logSurvE[i];
  }
}
generated quantities {
  real hr_out;
  hr_out = hr;
}
"
# Record wall-clock timestamps immediately before and after sampling.
st <- Sys.time()
# 3 chains of 5000 iterations each, of which the first 3000 are warmup.
fit <- stan(
  model_code = teststan,
  data = d.jags,
  chains = 3,
  iter = 5000,
  warmup = 3000
)
en <- Sys.time()