#!/usr/bin/env rb

# example command string to simulate replicates "sim.1" to "sim.10"
# (arguments after --args are: out_path prefix start_idx batch_size)
# cd ~/projects/phyddle/scripts
# rb ./sim/Rev/sim_one.Rev --args ../workspace/simulate/Rev_example sim 1 10

# Command-line configuration.
# With `--args out_path prefix start_idx batch_size` the script simulates
# replicates start_idx .. start_idx+batch_size-1 and seeds the RNG with
# start_idx. Without arguments it falls back to ten replicates (0..9)
# written to the current directory.
if (exists("args")) {
    out_path   = args[1]
    prefix     = args[2]
    start_idx  = args[3]
    batch_size = args[4]
    rep_idx    = start_idx:(start_idx+batch_size-1)
    seed(start_idx)
} else {
    rep_idx  = 0:9
    out_path = "./"
    # the replicate loop builds every filename from `prefix`, so the
    # fallback must define it as well; "sim" matches the example file
    # names used in the header comment (sim.1, sim.2, ...)
    prefix   = "sim"
}

# dataset dimensions
# num_char sizes the extirpation/dispersal rate arrays and the
# cladogenetic feature arrays built further down.
num_char <- 3
num_states <- 2
# every combination of num_states states over num_char characters, minus
# one (2^3 - 1 = 7 here); this is the length of the extinction-rate vector
# NOTE(review): presumably the excluded combination is the all-absent
# (empty) range -- confirm
num_ranges <- num_states^num_char - 1
# bounds of the uniform-integer prior on the number of taxa per simulation
min_taxa <- 2
max_taxa <- 1000
# when a simulated tree has more than tree_width taxa, the replicate loop
# sets the sampling fraction to tree_width / num_taxa
tree_width <- 300
# min_time <- 0.1
# max_time <- 5.0

# model parameters
# Each `~` variable is a stochastic node; the replicate loop calls
# .redraw() on it to obtain fresh values for every simulated dataset.
# Declaration order is kept as-is because each declaration draws an
# initial value from the seeded RNG.
#sample_frac ~ dnUnif(0.1, 1.0)
# rho_w: first entry of speciation_rates (paired with within-region features)
rho_w ~ dnLoguniform(0.05, 3.0)
# rho_e: extirpation rate and extinction rate; its upper bound is the
# current value of rho_w, so rho_e never exceeds rho_w
rho_e ~ dnLoguniform(0.05, rho_w)
# rho_d: dispersal rate between distinct characters
rho_d ~ dnLoguniform(0.05, 3.0)
# rho_b: second entry of speciation_rates (paired with between-region features)
rho_b ~ dnLoguniform(0.05, 3.0)
# number of lineages the tree simulation is conditioned on
num_taxa ~ dnUniformInteger(min_taxa, max_taxa)
# root age is held fixed; the random-root-age alternative is disabled
root_age <- 1.0
# root_age ~ dnUniform(min_time, max_time)


#for (i in 1:num_char) {
#    g_w[i] ~ dnExp(1)
#    for (j in 1:num_char) {
#        g_b[i][j] ~ dnExp(1)
#    }
#}

#m_w := fnRelativeFeatureRates( g_w, phi_w, sigma_w )

# anagenetic rates
# er[i]    : deterministic node equal to rho_e for every character
# dr[i][j] : constant 0.0 on the diagonal, deterministic node equal to
#            rho_d everywhere else
# Because the non-zero entries are `:=` nodes, redrawing rho_e / rho_d in
# the replicate loop updates er, dr and Q without rebuilding them.
for (i in 1:num_char) {
    er[i] := rho_e
    for (j in 1:num_char) {
        # start from a constant zero, then override the off-diagonal entries
        dr[i][j] <- 0.0
        if (i != j) {
            dr[i][j] := rho_d
        }
    }
}
# rate matrix assembled from the dispersal and extirpation rates above
Q := fnBiogeographyRateMatrix(dispersalRates=dr, extirpationRates=er)

# extinction rates
# mu has one entry per range (num_ranges). The first num_char entries are
# deterministic nodes equal to rho_e; all remaining entries stay at zero.
# NOTE(review): presumably the first num_char ranges are the single-region
# ranges, so only those can go extinct directly -- confirm against the
# state ordering used by the rate-matrix / cladogenetic functions.
for (i in 1:num_ranges) {
    # NOTE(review): abs() looks like a way to give the zero literal a
    # non-negative real type compatible with rho_e -- confirm
    mu[i] <- abs(0.)
    if (i <= num_char) {
        mu[i] := rho_e
    }
}

# cladogenetic rates
# speciation_rates holds the within-region (rho_w) and between-region
# (rho_b) rates, in that order.
speciation_rates := [ rho_w, rho_b ]
# All feature weights are set to 1. Only z_w and z_b are passed to
# fnBiogeographyCladoEventsBD below; z_e and z_d are assigned here but not
# referenced anywhere else in this script as shown.
for (i in 1:num_char) {
    z_e[i] <- 1.
    z_w[i] <- 1.
    for (j in 1:num_char) {
        z_d[i][j] <- 1.0
        z_b[i][j] <- 1.0
    }
}
# cladogenetic event rates, later passed to dnCDBDP as its speciationRates
R := fnBiogeographyCladoEventsBD(
    speciation_rates=speciation_rates,
    within_region_features=z_w,
    between_region_features=z_b
)
                           

# loop over replicates
# Each pass redraws the model parameters, simulates one tree together with
# its character data, and writes three files named
#   <out_path>/<prefix>.<idx>.tre         simulated phylogeny
#   <out_path>/<prefix>.<idx>.dat.nex     character matrix (Nexus)
#   <out_path>/<prefix>.<idx>.labels.csv  training labels
for (i in 1:rep_idx.size()) {
    idx = rep_idx[i]

    # filenames
    tmp_fn = out_path + "/" + prefix + "." + idx
    phy_fn = tmp_fn + ".tre"
    dat_fn = tmp_fn + ".dat.nex"
    lbl_fn = tmp_fn + ".labels.csv"

    # redraw parameter values
    # (rho_e comes after rho_w because rho_w is the upper bound of its prior)
    # root_age.redraw()
    num_taxa.redraw()
    rho_w.redraw()
    rho_e.redraw()
    rho_b.redraw()
    rho_d.redraw()

    # sample every taxon for small trees; for larger trees use the
    # fraction tree_width / num_taxa
    if (num_taxa <= tree_width) {
        sample_frac <- 1.0
    } else {
        sample_frac <- tree_width / num_taxa
    }

    # phylogeny
    phy ~ dnCDBDP(
        rootAge=root_age,
        speciationRates=R,
        extinctionRates=mu,
        rho=sample_frac,
        Q=Q,
        condition="survival",
        simulateCondition="numTips",
        exactNumLineages=num_taxa,
        pruneExtinctLineages=true)

    # print simulation summary (optional)
    print(i, num_taxa, phy.rootAge(), rho_w, rho_e, rho_b, rho_d, sample_frac)

    # save data matrix
    dat <- phy.getCharData()

    # make training labels
    # The header declares the first four columns as log10-transformed rates,
    # so all four rate values (including rho_w) are transformed; sample_frac
    # is written untransformed.
    label_str = "log10_rho_w,log10_rho_e,log10_rho_d,log10_rho_b,sample_frac\n"
    label_values = [ rho_w, rho_e, rho_d, rho_b, sample_frac ]
    num_log10 = 4
    # use a dedicated index so the replicate counter `i` is left untouched
    for (k in 1:label_values.size()) {
        v = label_values[k]
        if (k <= num_log10) {
            v = log(v, base=10)
        }
        if (k != 1) {
            label_str += ","
        }
        label_str += ""+v
    }
    label_str += "\n"

    # save output
    write(phy, filename=phy_fn)
    writeNexus(data=dat, filename=dat_fn)
    write(label_str, filename=lbl_fn)
}

# done!
# exit the interpreter once all replicates have been written
quit()