# number of samples; passed as 'n' to every sampling call in this script (unif, discrete_sse)
n_batches <- 1

# number of taxa: stochastic node drawn from unif with min=10 and max=500
# NOTE(review): 'num_taxa' is not referenced by any other statement in the visible script;
# the same bounds (10, 500) appear as min_rec_taxa/max_rec_taxa in the discrete_sse call --
# confirm whether this node is still needed or is consumed by downstream tooling
num_taxa ~ unif(n=n_batches, min=10, max=500)

# rate parameter assignment, one set per epoch ('_t0' and '_t1' suffixes)
# 'birth_rate0_t0' and 'birth_rate0_t1' are stochastic nodes (sampled with '~ unif');
# every other rate is a constant node (fixed with '<-')
# naming: birth/death rates carry a state suffix (0 or 1); 'trans_rateXY' is paired
# with states=[X,Y] in the sse_rate calls further down
# first epoch (old)
birth_rate0_t0 ~ unif(n=n_batches, min=0.50, max=1.00)
death_rate0_t0 <- 0.1
birth_rate1_t0 <- 0.5
death_rate1_t0 <- 0.1
trans_rate01_t0 <- 0.2
trans_rate10_t0 <- 0.2

# second epoch (young)
# state-0 birth rate is sampled from a lower range (0.25-0.50) than in the first epoch
# (0.50-1.00); all fixed rates are the same in both epochs
birth_rate0_t1 ~ unif(n=n_batches, min=0.25, max=0.50)
death_rate0_t1 <- 0.1
birth_rate1_t1 <- 0.5
death_rate1_t1 <- 0.1
trans_rate01_t1 <- 0.2
trans_rate10_t1 <- 0.2

# deterministic-node variable ('det_*') assignment
# each rate value defined above is wrapped with sse_rate, which attaches a name,
# the states involved, the event type and the epoch index
# NOTE(review): states=[s,s,s] for "w_speciation" presumably lists the parent state followed
# by the two daughter states, and states=[a,b] for "transition" presumably means a -> b --
# confirm against the PhyloJunction documentation
# first epoch (old): '_t0' rates, epoch=1
det_birth_rate0_t0 := sse_rate(name="lambda0_t0", value=birth_rate0_t0, states=[0,0,0], event="w_speciation", epoch=1)
det_death_rate0_t0 := sse_rate(name="mu0_t0", value=death_rate0_t0, states=[0], event="extinction", epoch=1)
det_birth_rate1_t0 := sse_rate(name="lambda1_t0", value=birth_rate1_t0, states=[1,1,1], event="w_speciation", epoch=1)
det_death_rate1_t0 := sse_rate(name="mu1_t0", value=death_rate1_t0, states=[1], event="extinction", epoch=1)
det_trans_rate01_t0 := sse_rate(name="q01_t0", value=trans_rate01_t0, states=[0,1], event="transition", epoch=1)
det_trans_rate10_t0 := sse_rate(name="q10_t0", value=trans_rate10_t0, states=[1,0], event="transition", epoch=1)

# second epoch (young): '_t1' rates, epoch=2
det_birth_rate0_t1 := sse_rate(name="lambda0_t1", value=birth_rate0_t1, states=[0,0,0], event="w_speciation", epoch=2)
det_death_rate0_t1 := sse_rate(name="mu0_t1", value=death_rate0_t1, states=[0], event="extinction", epoch=2)
det_birth_rate1_t1 := sse_rate(name="lambda1_t1", value=birth_rate1_t1, states=[1,1,1], event="w_speciation", epoch=2)
det_death_rate1_t1 := sse_rate(name="mu1_t1", value=death_rate1_t1, states=[1], event="extinction", epoch=2)
det_trans_rate01_t1 := sse_rate(name="q01_t1", value=trans_rate01_t1, states=[0,1], event="transition", epoch=2)
det_trans_rate10_t1 := sse_rate(name="q10_t1", value=trans_rate10_t1, states=[1,0], event="transition", epoch=2)

# deterministic-node variable ('stash') assignment
# sse_stash collects all 12 sse_rate nodes (6 per epoch, epoch-1 rates listed first) for
# a model with n_states=2 and n_epochs=2; a single epoch boundary is given at age 2.0
# NOTE(review): presumably the old epoch spans ages 10 -> 2 and the young epoch 2 -> 0 --
# confirm how epoch_age_ends is interpreted relative to seed_age
stash := sse_stash(flat_rate_mat=[det_birth_rate0_t0, det_death_rate0_t0, det_birth_rate1_t0, det_death_rate1_t0, det_trans_rate01_t0, det_trans_rate10_t0, det_birth_rate0_t1, det_death_rate0_t1, det_birth_rate1_t1, det_death_rate1_t1, det_trans_rate01_t1, det_trans_rate10_t1], n_states=2, n_epochs=2, epoch_age_ends=[2.0], seed_age=10)

# stochastic-node variable ('trs') assignment
# samples from discrete_sse using 'stash', starting in state 0 and stopping on age
# (stop="age", stop_value=10)
# NOTE(review): stop_value=10 duplicates seed_age=10 in the sse_stash call, and
# min_rec_taxa/max_rec_taxa (10, 500) duplicate the bounds used for 'num_taxa' --
# these are presumably meant to agree; keep them in sync when editing
trs ~ discrete_sse(n=n_batches, nr=1, stash=stash, start_state=[0], stop="age", stop_value=10, origin="true", runtime_limit=10, min_rec_taxa=10, max_rec_taxa=500)