#!/usr/bin/env rb

# example command string to simulate for "out.1" to "out.10"
# cd ~/projects/phyddle/workspace/mol_revbayes
# rb sim_mol.Rev --args ./simulate out 1 10

# Run settings: output directory, file prefix, first replicate index, and
# number of replicates. Built-in defaults apply when no `args` variable is
# present; otherwise the four settings are read positionally from `args`.
if (!exists("args")) {
    out_path   = "./simulate"
    prefix     = "out"
    start_idx  = 0
    batch_size = 10
} else {
    out_path   = args[1]
    prefix     = args[2]
    start_idx  = args[3]
    batch_size = args[4]
}

# seed the random number generator with the first replicate index
seed(start_idx)

# replicate indices handled by this run: start_idx, ..., start_idx+batch_size-1
rep_idx    = start_idx:(start_idx+batch_size-1)

# dataset dimensions
# num_char   : number of sites passed to dnPhyloCTMC (nSites)
# num_states : not referenced anywhere else in the visible script
# min_taxa, max_taxa : bounds of the uniform prior on num_taxa
# tree_width : taxon-count threshold; the replicate loop sets the sampling
#              fraction to tree_width / num_taxa when num_taxa exceeds it
num_char <- 10
num_states <- 4
min_taxa <- 2
max_taxa <- 1000
tree_width <- 300

# model parameters
# NOTE: the order of the stochastic declarations below determines the order
# in which initial values are drawn, so keep it stable.
# mu    : passed to dnPhyloCTMC as branchRates
# kappa : argument of the K80 rate matrix
# birth : passed to dnBDP as lambda; drawn once here and not redrawn in the
#         replicate loop -- NOTE(review): confirm this is intended
mu ~ dnLoguniform(0.1,1.0)
kappa ~ dnGamma(2,2)
birth ~ dnExp(10)

# number of taxa per replicate; redrawn at the start of every replicate
num_taxa ~ dnUniformInteger(min_taxa, max_taxa)
# fixed root age passed to dnBDP (rootAge)
root_age <- 1.0
# `:=` defines Q as a function of kappa rather than a one-off value
Q := fnK80(kappa)

# loop over replicates
# For every replicate index: redraw parameters, build the taxon set, simulate
# a tree and a character matrix, then write the tree, the matrix, and a
# one-row CSV of training labels.
# Loop counters are kept distinct (i = replicate, j = taxon, k = label) so the
# inner loops cannot overwrite the replicate counter before it is printed.
for (i in 1:rep_idx.size()) {
    idx = rep_idx[i]

    # filenames: <out_path>/<prefix>.<idx>.{tre, dat.nex, labels.csv}
    tmp_fn = out_path + "/" + prefix + "." + idx
    phy_fn = tmp_fn + ".tre"
    dat_fn = tmp_fn + ".dat.nex"
    lbl_fn = tmp_fn + ".labels.csv"

    # redraw parameter values
    num_taxa.redraw()
    mu.redraw()
    kappa.redraw()

    # sampling fraction: 1.0 up to tree_width taxa, tree_width/num_taxa beyond
    if (num_taxa <= tree_width) {
        sample_frac <- 1.0
    } else {
        sample_frac <- tree_width / num_taxa
    }

    # taxa: discard the previous replicate's taxon vector, then rebuild it
    if (exists("taxa")) { clear(taxa) }
    for (j in 1:num_taxa) {
        taxa[j] = taxon("T"+j)
    }

    # phylogeny
    # The dnBDP argument `mu` is fixed at 0.0 here; it is unrelated to the
    # workspace variable `mu`, which is used as the branch rate below.
    phy ~ dnBDP(
        rootAge=root_age,
        lambda=birth,
        mu=0.0,
        rho=sample_frac,
        condition="nTaxa",
        taxa=taxa)

    # characters
    dat ~ dnPhyloCTMC(
        tree=phy,
        Q=Q,
        branchRates=mu,
        nSites=num_char,
        type="Standard")

    # print simulation summary (optional); `i` is the replicate counter
    print(i, num_taxa, phy.rootAge(), mu, kappa, birth, sample_frac)

    # make training labels: a header row followed by one comma-separated row
    label_str = "mu,kappa,sample_frac\n"
    label_values = [ mu, kappa, sample_frac ]
    for (k in 1:label_values.size()) {
        v = label_values[k]
        # log10-transform entries 2..4; with three labels this means kappa
        # and sample_frac are transformed while mu is written untransformed.
        # NOTE(review): confirm that leaving mu on the natural scale (and the
        # unprefixed header names) is what downstream consumers expect.
        if (k > 1 && k <= 4) {
            v = log(v, base=10)
        }
        if (k != 1) {
            label_str += ","
        }
        label_str += ""+v
    }
    label_str += "\n"

    # save output
    write(phy, filename=phy_fn)
    writeNexus(data=dat, filename=dat_fn)
    write(label_str, filename=lbl_fn)
}

# done!
quit()