/* * lrnnSMDDS - linear RNN/Reservoir hybrid for CPU (C Implementation) * * PolyForm Noncommercial License 1.0.0 * https://polyformproject.org/licenses/noncommercial/1.0.0/ * * ====================================================== * * Features: * S. SwiGLU in Channel Mixing (more coherence) * M. Multi-Scale Token Shift (larger context/"infinite") * D. Data-Dependent Decay with Low-Rank (speed in large context) * D. Dynamic State Checkpointing (faster/linear generation) * S. Slot-memory resorvoir (perfect recall, transformers style, legacy/proven) * * Compile on cygwin on Windows (or POSIX linux, gcc is needed!): * * gcc -std=c17 -O3 -march=native -Wall --fast-math -Wextra -o lrnn aismdd.c * * Usage: * Training: ./lrnn --train corpus.txt --save model.bin --epochs 20 * Generate: ./lrnn --load model.bin --seed "Hello world" --tokens 200 */ #define _POSIX_C_SOURCE 200809L #include #include #include #include #include #include #include #include #include #include /* ============================================================ * Constants and Configuration * ============================================================ */ #define MAX_VOCAB_SIZE 65536 #define MAX_LAYERS 32 #define EPSILON 1e-6f #define GRAD_CLIP 50.0f typedef struct { int vocab_size; int n_layer; int n_embd; int n_head; int ctx_len; int decay_lora_rank; float ffn_multiplier; int n_mem_slots; } lrnnConfig; static lrnnConfig default_config(void) { lrnnConfig cfg = { .vocab_size = 256, // overwritten by build_vocabulary() .n_layer = 2, .n_embd = 64, .n_head = 1, .ctx_len = 64, .decay_lora_rank = 2, .ffn_multiplier = 2.0f, .n_mem_slots = 4 }; return cfg; } static inline int ffn_hidden(const lrnnConfig *cfg) { return (int)(cfg->n_embd * cfg->ffn_multiplier); } //forward help /* ============================================================ * Tensor Structure and Operations * ============================================================ */ typedef struct { float *data; int rows; int cols; int size; } Tensor; static 
Tensor tensor_alloc(int rows, int cols) { Tensor t; t.rows = rows; t.cols = cols; t.size = rows * cols; t.data = NULL; if (t.size > 0) { t.data = (float *)calloc((size_t)t.size, sizeof(float)); if (!t.data) { fprintf(stderr, "Error: tensor allocation failed (%d x %d)\n", rows, cols); exit(1); } } return t; } static Tensor tensor_alloc_1d(int size) { return tensor_alloc(size, 1); } static void tensor_free(Tensor *t) { if (t && t->data) { free(t->data); t->data = NULL; } if (t) { t->rows = t->cols = t->size = 0; } } static void tensor_copy(Tensor *dst, const Tensor *src) { if (dst->size != src->size) { fprintf(stderr, "Error: tensor_copy size mismatch (%d vs %d)\n", dst->size, src->size); exit(1); } if (src->size > 0) { memcpy(dst->data, src->data, (size_t)src->size * sizeof(float)); } } static void tensor_fill(Tensor *t, float val) { for (int i = 0; i < t->size; i++) { t->data[i] = val; } } static void tensor_zero(Tensor *t) { if (t->data && t->size > 0) { memset(t->data, 0, (size_t)t->size * sizeof(float)); } } /* Random initialization */ static float randn(void) { /* Box-Muller transform */ float u1 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f); float u2 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f); return sqrtf(-2.0f * logf(u1)) * cosf(2.0f * 3.14159265f * u2); } static float rand_uniform(float lo, float hi) { return lo + ((float)rand() / (float)RAND_MAX) * (hi - lo); } static void tensor_randn(Tensor *t, float scale) { for (int i = 0; i < t->size; i++) { t->data[i] = randn() * scale; } } static void tensor_rand_uniform(Tensor *t, float lo, float hi) { for (int i = 0; i < t->size; i++) { t->data[i] = rand_uniform(lo, hi); } } /* ============================================================ * Activation Functions * ============================================================ */ static inline float sigmoid_f(float x) { return 1.0f / (1.0f + expf(-x)); } static inline float silu_f(float x) { return x * sigmoid_f(x); } static inline float clamp_f(float x, 
float lo, float hi) { if (x < lo) return lo; if (x > hi) return hi; return x; } static void sigmoid_vec(float *out, const float *in, int n) { for (int i = 0; i < n; i++) { out[i] = sigmoid_f(in[i]); } } static void exp_vec(float *out, const float *in, int n) { for (int i = 0; i < n; i++) { out[i] = expf(clamp_f(in[i], -10.0f, 10.0f)); } } static void softmax_vec(float *out, const float *in, int n) { float max_val = in[0]; for (int i = 1; i < n; i++) { if (in[i] > max_val) max_val = in[i]; } float sum = 0.0f; for (int i = 0; i < n; i++) { out[i] = expf(in[i] - max_val); sum += out[i]; } float inv_sum = 1.0f / (sum + EPSILON); for (int i = 0; i < n; i++) { out[i] *= inv_sum; } } /* ============================================================ * Vector/Matrix Operations * ============================================================ */ /* out = a + b (element-wise) */ static void vec_add(float *out, const float *a, const float *b, int n) { for (int i = 0; i < n; i++) { out[i] = a[i] + b[i]; } } /* out = a * b (element-wise) */ static void vec_mul(float *out, const float *a, const float *b, int n) { for (int i = 0; i < n; i++) { out[i] = a[i] * b[i]; } } /* out = x @ W where x: (1, in_dim), W: (in_dim, out_dim) -> out: (1, out_dim) */ static void matvec(float *out, const float *x, const Tensor *W) { int in_dim = W->rows; int out_dim = W->cols; for (int j = 0; j < out_dim; j++) { float sum = 0.0f; for (int i = 0; i < in_dim; i++) { sum += x[i] * W->data[i * out_dim + j]; } out[j] = sum; } } /* out = X @ W where X: (seq_len, in_dim), W: (in_dim, out_dim) -> out: (seq_len, out_dim) */ static void matmul(Tensor *out, const Tensor *X, const Tensor *W) { int seq_len = X->rows; int in_dim = X->cols; int out_dim = W->cols; if (W->rows != in_dim) { fprintf(stderr, "matmul dimension mismatch: X(%d,%d) @ W(%d,%d)\n", X->rows, X->cols, W->rows, W->cols); exit(1); } for (int s = 0; s < seq_len; s++) { for (int j = 0; j < out_dim; j++) { float sum = 0.0f; for (int i = 0; i < in_dim; 
i++) { sum += X->data[s * in_dim + i] * W->data[i * out_dim + j]; } out->data[s * out_dim + j] = sum; } } } /* Layer normalization: out = weight * (x - mean) / sqrt(var + eps) + bias */ static void layer_norm(float *out, const float *x, const float *weight, const float *bias, int n) { float mean = 0.0f; for (int i = 0; i < n; i++) { mean += x[i]; } mean /= (float)n; float var = 0.0f; for (int i = 0; i < n; i++) { float d = x[i] - mean; var += d * d; } var /= (float)n; float inv_std = 1.0f / sqrtf(var + EPSILON); for (int i = 0; i < n; i++) { out[i] = weight[i] * (x[i] - mean) * inv_std + bias[i]; } } /* Layer norm for sequence: (seq_len, n_embd) */ static void layer_norm_seq(Tensor *out, const Tensor *x, const Tensor *weight, const Tensor *bias) { int seq_len = x->rows; int n_embd = x->cols; for (int s = 0; s < seq_len; s++) { layer_norm(out->data + s * n_embd, x->data + s * n_embd, weight->data, bias->data, n_embd); } } /* ============================================================ * ALiBi - Attention with Linear Biases * ============================================================ * Each head h gets slope m_h = 1 / 2^(8h/H) where H = n_head. * For token at position t attending to position s: * bias = -m_h * |t - s| * * In WKV recurrence, this manifests as an additional * geometric decay per head that compounds with the * data-dependent decay. 
* ============================================================ */

/*
 * Fill slopes[0 .. n_head-1] with the fixed (non-learned) per-head ALiBi
 * slopes 2^(-8*(h+1)/n_head). Head 0 gets the largest slope and the last
 * head gets 2^-8.
 */
static void compute_alibi_slopes(float *slopes, int n_head) {
    /* slopes[h] = 2^(-8*(h+1)/n_head) */
    for (int h = 0; h < n_head; h++) {
        float exponent = -8.0f * (float)(h + 1) / (float)n_head;
        slopes[h] = powf(2.0f, exponent);
    }
}

/* ============================================================
 * Layer Parameters Structure
 * ============================================================ */

/*
 * Tensors of one layer. Shapes are the ones init_layer_params() allocates;
 * "ffn_h" below means ffn_hidden(cfg).
 */
typedef struct {
    /* Layer norms, each (n_embd): ln1 precedes TimeMix, ln2 precedes ChannelMix */
    Tensor ln1_weight, ln1_bias;
    Tensor ln2_weight, ln2_bias;
    /* Multi-scale token shift, each (n_embd): raw weights for the inputs 1, 2
     * and 4 tokens back; forward_single normalizes their sigmoids to sum ~1 */
    Tensor time_shift_w1, time_shift_w2, time_shift_w4;
    /* Token mixing ratios, each (n_embd): sigmoid gives the current token's share */
    Tensor time_mix_r, time_mix_k, time_mix_v;
    /* Data-dependent decay (low-rank):
     * decay = sigmoid(decay_base + x_norm @ decay_lora_a @ decay_lora_b),
     * decay_lora_a (n_embd, rank), decay_lora_b (rank, n_embd) */
    Tensor decay_lora_a, decay_lora_b;
    Tensor decay_base; /* (n_embd) */
    Tensor time_first; /* (n_embd): exp(time_first) weights the current token in WKV */
    /* Projections, each (n_embd, n_embd) */
    Tensor Wr, Wk, Wv, Wo;
    /* Channel mix: channel_mix (n_embd) is the sigmoid share of the current token;
     * ffn_gate/ffn_up (n_embd, ffn_h), ffn_down (ffn_h, n_embd) form the SwiGLU */
    Tensor channel_mix;
    Tensor ffn_gate, ffn_up, ffn_down;
    Tensor alibi_slopes; /* (n_head): fixed, not written to model files; recomputed on load */
    Tensor mem_gate_write; /* (n_embd, n_mem_slots) - write gate */
    Tensor mem_gate_read; /* (n_embd, n_mem_slots) - read gate */
} LayerParams;

/* Whole-model parameters. */
typedef struct {
    Tensor emb;                   /* (vocab_size, n_embd) token embeddings */
    Tensor ln0_weight, ln0_bias;  /* norm applied directly after the embedding */
    LayerParams *layers;          /* n_layers entries */
    Tensor ln_out_weight, ln_out_bias;
    Tensor head;                  /* (n_embd, vocab_size) output projection */
    int n_layers;
} ModelParams;

/* Layer state for generation (carried from one forward_single() call to the next) */
typedef struct {
    /* ln1-normalized inputs from 1..4 tokens back; forward_single mixes
     * x_prev_1, x_prev_2 and x_prev_4 -- x_prev_3 is only a staging slot */
    Tensor x_prev_1, x_prev_2, x_prev_3, x_prev_4;
    Tensor wkv_num; /* (n_mem_slots, n_embd) */
    Tensor wkv_den; /* (n_mem_slots, n_embd) */
    Tensor ffn_prev; /* previous token's ln2-normalized input (ChannelMix shift) */
} LayerState;

/* Per-layer generation state for the whole model. */
typedef struct {
    LayerState *layers;
    int n_layers;
} ModelState;

/* Vocabulary (character level). build_vocabulary() stores the distinct bytes of
 * the corpus in ascending order, so at most 256 of the chars[] entries are used;
 * char_to_idx maps a byte to its index (bytes not in the corpus map to 0). */
typedef struct {
    char chars[MAX_VOCAB_SIZE];
    int char_to_idx[256];
    int size;
} Vocabulary;

/* ============================================================
 * Word-Level Tokenizer Structures
 * ============================================================ */

typedef enum {
    TOKENIZER_CHAR,
    TOKENIZER_WORD,
    TOKENIZER_AUTO
} TokenizerType;

#define MAX_WORD_LEN 64
#define MAX_WORDS 32768
#define WORD_HASH_SIZE 65536

/* Word vocabulary: words[] owns strdup'ed strings; lookups go through an
 * open-addressing (linear probing) hash table of WORD_HASH_SIZE slots. */
typedef struct {
    char **words; /* Array of word strings */
    int *hash_table; /* Hash table: hash -> word index (-1 if empty) */
    int *hash_keys; /* Hash table: stored hash for collision detection */
    int size; /* Number of words */
    int capacity; /* Allocated capacity */
    int unk_idx; /* unknown-word token index (word_vocab_add's fallback when full) */
    int pad_idx; /* padding token index. NOTE(review): build_word_vocabulary()
                    registers unk and pad as the same string, so pad_idx
                    currently equals unk_idx -- confirm intent */
    int space_idx; /* Space token index */
    int newline_idx; /* Newline token index */
} WordVocabulary;

/* Tokenizer front-end: exactly one of the two vocabularies is in use. */
typedef struct {
    TokenizerType type;
    Vocabulary char_vocab; /* Used if type == TOKENIZER_CHAR */
    WordVocabulary word_vocab; /* Used if type != TOKENIZER_CHAR (save/load treat every other type as word-level) */
} Tokenizer;

/* ============================================================
 * Forward Declarations
 * ============================================================ */

/* Word vocabulary functions */
static void init_word_vocabulary(WordVocabulary *wv);
static void free_word_vocabulary(WordVocabulary *wv);
static int word_vocab_find(const WordVocabulary *wv, const char *word);
static int word_vocab_add(WordVocabulary *wv, const char *word);
static void build_word_vocabulary(WordVocabulary *wv, const char *text, size_t len);
static int *tokenize_words(const char *text, size_t len, const WordVocabulary *wv, int *out_len);
static const char *decode_word_token(int token, const WordVocabulary *wv);

/* Character vocabulary functions */
static void build_vocabulary(Vocabulary *vocab, const char *text, size_t len);
static int *tokenize(const char *text, size_t len, const Vocabulary *vocab, int *out_len);

/* Tokenizer interface */
static void init_tokenizer(Tokenizer *tok, TokenizerType type);
static void free_tokenizer(Tokenizer *tok);
static void build_tokenizer(Tokenizer *tok, const char *text, size_t len, TokenizerType requested_type);
static int tokenizer_vocab_size(const Tokenizer *tok);
static int *tokenizer_encode(const Tokenizer *tok, const char *text, size_t len, int *out_len);
static void tokenizer_decode_token(const Tokenizer *tok, int token, char *out, int out_size);

/* ============================================================
* Parameter Initialization * ============================================================ */ static void init_layer_params(LayerParams *lp, const lrnnConfig *cfg, int layer_idx) { int n_embd = cfg->n_embd; int ffn_h = ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; float proj_scale = 0.02f / sqrtf((float)cfg->n_layer); float ffn_scale = proj_scale; //reservoir/slots initialization int n_slots = cfg->n_mem_slots; lp->mem_gate_write = tensor_alloc(n_embd, n_slots); tensor_randn(&lp->mem_gate_write, 0.01f); lp->mem_gate_read = tensor_alloc(n_embd, n_slots); tensor_randn(&lp->mem_gate_read, 0.01f); /* Layer norms */ lp->ln1_weight = tensor_alloc_1d(n_embd); tensor_fill(&lp->ln1_weight, 1.0f); lp->ln1_bias = tensor_alloc_1d(n_embd); lp->ln2_weight = tensor_alloc_1d(n_embd); tensor_fill(&lp->ln2_weight, 1.0f); lp->ln2_bias = tensor_alloc_1d(n_embd); /* Multi-scale token shift */ lp->time_shift_w1 = tensor_alloc_1d(n_embd); tensor_rand_uniform(&lp->time_shift_w1, 0.3f, 0.7f); lp->time_shift_w2 = tensor_alloc_1d(n_embd); tensor_rand_uniform(&lp->time_shift_w2, 0.1f, 0.3f); lp->time_shift_w4 = tensor_alloc_1d(n_embd); tensor_rand_uniform(&lp->time_shift_w4, 0.0f, 0.2f); /* Token mixing */ lp->time_mix_r = tensor_alloc_1d(n_embd); tensor_fill(&lp->time_mix_r, 0.5f); lp->time_mix_k = tensor_alloc_1d(n_embd); tensor_fill(&lp->time_mix_k, 0.5f); lp->time_mix_v = tensor_alloc_1d(n_embd); tensor_fill(&lp->time_mix_v, 0.5f); /* Decay LoRA */ lp->decay_lora_a = tensor_alloc(n_embd, lora_rank); tensor_randn(&lp->decay_lora_a, 0.01f); lp->decay_lora_b = tensor_alloc(lora_rank, n_embd); tensor_randn(&lp->decay_lora_b, 0.01f); /* Per-head initialization for better multi-head diversity */ lp->decay_base = tensor_alloc_1d(n_embd); { int hdim_val = n_embd / cfg->n_head; for (int h = 0; h < cfg->n_head; h++) { float base_val = 1.5f - 0.1f * layer_idx - 0.2f * h; for (int d = 0; d < hdim_val; d++) { lp->decay_base.data[h * hdim_val + d] = base_val; } } } lp->time_first = 
tensor_alloc_1d(n_embd); { int hdim_val = n_embd / cfg->n_head; for (int h = 0; h < cfg->n_head; h++) { float tf_val = -3.0f + layer_idx * 0.3f + h * 0.5f; for (int d = 0; d < hdim_val; d++) { lp->time_first.data[h * hdim_val + d] = tf_val; } } } /* Projections */ lp->Wr = tensor_alloc(n_embd, n_embd); tensor_randn(&lp->Wr, proj_scale); lp->Wk = tensor_alloc(n_embd, n_embd); tensor_randn(&lp->Wk, proj_scale); lp->Wv = tensor_alloc(n_embd, n_embd); tensor_randn(&lp->Wv, proj_scale); lp->Wo = tensor_alloc(n_embd, n_embd); tensor_randn(&lp->Wo, proj_scale); /* Channel mix */ lp->channel_mix = tensor_alloc_1d(n_embd); tensor_fill(&lp->channel_mix, 0.5f); lp->ffn_gate = tensor_alloc(n_embd, ffn_h); tensor_randn(&lp->ffn_gate, ffn_scale); lp->ffn_up = tensor_alloc(n_embd, ffn_h); tensor_randn(&lp->ffn_up, ffn_scale); lp->ffn_down = tensor_alloc(ffn_h, n_embd); tensor_randn(&lp->ffn_down, ffn_scale); /* ALiBi slopes - fixed, not learned */ lp->alibi_slopes = tensor_alloc_1d(cfg->n_head); compute_alibi_slopes(lp->alibi_slopes.data, cfg->n_head); } static void free_layer_params(LayerParams *lp) { tensor_free(&lp->ln1_weight); tensor_free(&lp->ln1_bias); tensor_free(&lp->ln2_weight); tensor_free(&lp->ln2_bias); tensor_free(&lp->time_shift_w1); tensor_free(&lp->time_shift_w2); tensor_free(&lp->time_shift_w4); tensor_free(&lp->time_mix_r); tensor_free(&lp->time_mix_k); tensor_free(&lp->time_mix_v); tensor_free(&lp->decay_lora_a); tensor_free(&lp->decay_lora_b); tensor_free(&lp->decay_base); tensor_free(&lp->time_first); tensor_free(&lp->Wr); tensor_free(&lp->Wk); tensor_free(&lp->Wv); tensor_free(&lp->Wo); tensor_free(&lp->channel_mix); tensor_free(&lp->ffn_gate); tensor_free(&lp->ffn_up); tensor_free(&lp->ffn_down); tensor_free(&lp->alibi_slopes); tensor_free(&lp->mem_gate_write); tensor_free(&lp->mem_gate_read); } static void init_model_params(ModelParams *mp, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int vocab_size = cfg->vocab_size; if (n_embd % cfg->n_head != 0) 
{ fprintf(stderr, "Error: n_embd (%d) must be divisible by n_head (%d)\n", n_embd, cfg->n_head); exit(1); } mp->n_layers = cfg->n_layer; /* Embedding */ mp->emb = tensor_alloc(vocab_size, n_embd); tensor_randn(&mp->emb, 0.02f); /* Initial layer norm */ mp->ln0_weight = tensor_alloc_1d(n_embd); tensor_fill(&mp->ln0_weight, 1.0f); mp->ln0_bias = tensor_alloc_1d(n_embd); /* Layers */ mp->layers = (LayerParams *)calloc((size_t)cfg->n_layer, sizeof(LayerParams)); if (!mp->layers) { fprintf(stderr, "Error: failed to allocate layers\n"); exit(1); } for (int i = 0; i < cfg->n_layer; i++) { init_layer_params(&mp->layers[i], cfg, i); } /* Output */ mp->ln_out_weight = tensor_alloc_1d(n_embd); tensor_fill(&mp->ln_out_weight, 1.0f); mp->ln_out_bias = tensor_alloc_1d(n_embd); mp->head = tensor_alloc(n_embd, vocab_size); tensor_randn(&mp->head, 0.02f); } static void free_model_params(ModelParams *mp) { tensor_free(&mp->emb); tensor_free(&mp->ln0_weight); tensor_free(&mp->ln0_bias); for (int i = 0; i < mp->n_layers; i++) { free_layer_params(&mp->layers[i]); } free(mp->layers); mp->layers = NULL; tensor_free(&mp->ln_out_weight); tensor_free(&mp->ln_out_bias); tensor_free(&mp->head); } /* ============================================================ * State Management * ============================================================ */ static void init_layer_state(LayerState *ls, int n_embd, int n_mem_slots) { ls->x_prev_1 = tensor_alloc_1d(n_embd); ls->x_prev_2 = tensor_alloc_1d(n_embd); ls->x_prev_3 = tensor_alloc_1d(n_embd); ls->x_prev_4 = tensor_alloc_1d(n_embd); ls->wkv_num = tensor_alloc(n_mem_slots, n_embd); ls->wkv_den = tensor_alloc(n_mem_slots, n_embd); ls->ffn_prev = tensor_alloc_1d(n_embd); } static void free_layer_state(LayerState *ls) { tensor_free(&ls->x_prev_1); tensor_free(&ls->x_prev_2); tensor_free(&ls->x_prev_3); tensor_free(&ls->x_prev_4); tensor_free(&ls->wkv_num); tensor_free(&ls->wkv_den); tensor_free(&ls->ffn_prev); } static void init_model_state(ModelState 
*ms, const lrnnConfig *cfg) { ms->n_layers = cfg->n_layer; ms->layers = (LayerState *)calloc((size_t)cfg->n_layer, sizeof(LayerState)); if (!ms->layers) { fprintf(stderr, "Error: failed to allocate state\n"); exit(1); } for (int i = 0; i < cfg->n_layer; i++) { init_layer_state(&ms->layers[i], cfg->n_embd, cfg->n_mem_slots); } } static void free_model_state(ModelState *ms) { for (int i = 0; i < ms->n_layers; i++) { free_layer_state(&ms->layers[i]); } free(ms->layers); ms->layers = NULL; } static inline int head_dim(const lrnnConfig *cfg) { return cfg->n_embd / cfg->n_head; } /* ============================================================ * Forward Pass - Single Token (for Generation) * ============================================================ */ static void forward_single(float *logits, int token, const ModelParams *mp, ModelState *state, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int n_head = cfg->n_head; int hdim = n_embd / n_head; int ffn_h = ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; float *x = (float *)malloc((size_t)n_embd * sizeof(float)); float *x_norm = (float *)malloc((size_t)n_embd * sizeof(float)); float *x_shifted= (float *)malloc((size_t)n_embd * sizeof(float)); float *xr = (float *)malloc((size_t)n_embd * sizeof(float)); float *xk = (float *)malloc((size_t)n_embd * sizeof(float)); float *xv = (float *)malloc((size_t)n_embd * sizeof(float)); float *r = (float *)malloc((size_t)n_embd * sizeof(float)); float *k = (float *)malloc((size_t)n_embd * sizeof(float)); float *v = (float *)malloc((size_t)n_embd * sizeof(float)); float *decay_delta = (float *)malloc((size_t)n_embd * sizeof(float)); float *decay = (float *)malloc((size_t)n_embd * sizeof(float)); float *k_exp = (float *)malloc((size_t)n_embd * sizeof(float)); float *time_first_val = (float *)malloc((size_t)n_embd * sizeof(float)); float *wkv = (float *)malloc((size_t)n_embd * sizeof(float)); float *tm_out = (float *)malloc((size_t)n_embd * sizeof(float)); float *xm = (float 
*)malloc((size_t)n_embd * sizeof(float)); float *gate = (float *)malloc((size_t)ffn_h * sizeof(float)); float *up = (float *)malloc((size_t)ffn_h * sizeof(float)); float *hidden = (float *)malloc((size_t)ffn_h * sizeof(float)); float *cm_out = (float *)malloc((size_t)n_embd * sizeof(float)); float *lora_tmp = (float *)malloc((size_t)lora_rank * sizeof(float)); float *w1_sig = (float *)malloc((size_t)n_embd * sizeof(float)); float *w2_sig = (float *)malloc((size_t)n_embd * sizeof(float)); float *w4_sig = (float *)malloc((size_t)n_embd * sizeof(float)); /* Token embedding */ memcpy(x, mp->emb.data + token * n_embd, (size_t)n_embd * sizeof(float)); /* Initial layer norm */ layer_norm(x, x, mp->ln0_weight.data, mp->ln0_bias.data, n_embd); for (int layer_idx = 0; layer_idx < mp->n_layers; layer_idx++) { const LayerParams *lp = &mp->layers[layer_idx]; LayerState *ls = &state->layers[layer_idx]; /* ============ TimeMix ============ */ layer_norm(x_norm, x, lp->ln1_weight.data, lp->ln1_bias.data, n_embd); /* Multi-scale shift */ sigmoid_vec(w1_sig, lp->time_shift_w1.data, n_embd); sigmoid_vec(w2_sig, lp->time_shift_w2.data, n_embd); sigmoid_vec(w4_sig, lp->time_shift_w4.data, n_embd); for (int i = 0; i < n_embd; i++) { float w_sum = w1_sig[i] + w2_sig[i] + w4_sig[i] + EPSILON; float nw1 = w1_sig[i] / w_sum; float nw2 = w2_sig[i] / w_sum; float nw4 = w4_sig[i] / w_sum; x_shifted[i] = nw1 * ls->x_prev_1.data[i] + nw2 * ls->x_prev_2.data[i] + nw4 * ls->x_prev_4.data[i]; } /* Mix current with shifted */ for (int i = 0; i < n_embd; i++) { float mr = sigmoid_f(lp->time_mix_r.data[i]); float mk = sigmoid_f(lp->time_mix_k.data[i]); float mv = sigmoid_f(lp->time_mix_v.data[i]); xr[i] = x_norm[i] * mr + x_shifted[i] * (1.0f - mr); xk[i] = x_norm[i] * mk + x_shifted[i] * (1.0f - mk); xv[i] = x_norm[i] * mv + x_shifted[i] * (1.0f - mv); } /* R, K, V projections */ matvec(r, xr, &lp->Wr); matvec(k, xk, &lp->Wk); matvec(v, xv, &lp->Wv); /* Data-dependent decay */ matvec(lora_tmp, 
x_norm, &lp->decay_lora_a); matvec(decay_delta, lora_tmp, &lp->decay_lora_b); for (int i = 0; i < n_embd; i++) { decay[i] = sigmoid_f(lp->decay_base.data[i] + decay_delta[i]); } /* Receptance gate */ sigmoid_vec(r, r, n_embd); for (int i = 0; i < n_embd; i++) { time_first_val[i] = expf(clamp_f(lp->time_first.data[i], -10.0f, 10.0f)); } exp_vec(k_exp, k, n_embd); /* ---- Multi-Head Multi-Slot WKV with ALiBi ---- */ { int n_slots = cfg->n_mem_slots; /* Compute write gates: softmax(x_norm @ mem_gate_write) -> (n_slots,) */ float *write_logits = (float *)malloc((size_t)n_slots * sizeof(float)); float *write_gates = (float *)malloc((size_t)n_slots * sizeof(float)); float *read_logits = (float *)malloc((size_t)n_slots * sizeof(float)); float *read_gates = (float *)malloc((size_t)n_slots * sizeof(float)); matvec(write_logits, x_norm, &lp->mem_gate_write); softmax_vec(write_gates, write_logits, n_slots); matvec(read_logits, x_norm, &lp->mem_gate_read); softmax_vec(read_gates, read_logits, n_slots); for (int h = 0; h < n_head; h++) { int base = h * hdim; float alibi_decay_h = expf(-lp->alibi_slopes.data[h]); for (int d = 0; d < hdim; d++) { int i = base + d; float kv = k_exp[i] * v[i]; /* Read: weighted sum across slots */ float read_num = 0.0f, read_den = 0.0f; for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; read_num += read_gates[s] * ls->wkv_num.data[si]; read_den += read_gates[s] * ls->wkv_den.data[si]; } /* WKV output with time_first boost */ float num = read_num + time_first_val[i] * kv; float den = read_den + time_first_val[i] * k_exp[i] + EPSILON; wkv[i] = num / den; /* Write: update each slot weighted by write gate */ for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; float wg = write_gates[s]; float combined = decay[i] * alibi_decay_h; /* Slot update: interpolate between decay and new info */ ls->wkv_num.data[si] = combined * ls->wkv_num.data[si] + wg * kv; ls->wkv_den.data[si] = combined * ls->wkv_den.data[si] + wg * k_exp[i]; } } } 
free(write_logits); free(write_gates); free(read_logits); free(read_gates); } /* Apply receptance and output projection */ vec_mul(wkv, r, wkv, n_embd); matvec(tm_out, wkv, &lp->Wo); vec_add(x, x, tm_out, n_embd); /* Update previous tokens */ tensor_copy(&ls->x_prev_4, &ls->x_prev_3); tensor_copy(&ls->x_prev_3, &ls->x_prev_2); tensor_copy(&ls->x_prev_2, &ls->x_prev_1); memcpy(ls->x_prev_1.data, x_norm, (size_t)n_embd * sizeof(float)); /* ============ ChannelMix ============ */ layer_norm(x_norm, x, lp->ln2_weight.data, lp->ln2_bias.data, n_embd); for (int i = 0; i < n_embd; i++) { float mix = sigmoid_f(lp->channel_mix.data[i]); xm[i] = x_norm[i] * mix + ls->ffn_prev.data[i] * (1.0f - mix); } /* SwiGLU */ matvec(gate, xm, &lp->ffn_gate); matvec(up, xm, &lp->ffn_up); for (int i = 0; i < ffn_h; i++) { hidden[i] = silu_f(gate[i]) * up[i]; } matvec(cm_out, hidden, &lp->ffn_down); vec_add(x, x, cm_out, n_embd); memcpy(ls->ffn_prev.data, x_norm, (size_t)n_embd * sizeof(float)); } /* Output layer norm and projection */ layer_norm(x, x, mp->ln_out_weight.data, mp->ln_out_bias.data, n_embd); matvec(logits, x, &mp->head); /* Cleanup */ free(x); free(x_norm); free(x_shifted); free(xr); free(xk); free(xv); free(r); free(k); free(v); free(decay_delta); free(decay); free(k_exp); free(time_first_val); free(wkv); free(tm_out); free(xm); free(gate); free(up); free(hidden); free(cm_out); free(lora_tmp); free(w1_sig); free(w2_sig); free(w4_sig); } /* ============================================================ * Loss Computation * ============================================================ */ static float cross_entropy_loss(const Tensor *logits, const int *targets, int n) { int vocab_size = logits->cols; float *probs = (float *)malloc((size_t)vocab_size * sizeof(float)); float total_loss = 0.0f; for (int t = 0; t < n; t++) { softmax_vec(probs, logits->data + t * vocab_size, vocab_size); int target = targets[t]; float p = probs[target]; if (p < EPSILON) p = EPSILON; total_loss -= 
logf(p); } free(probs); return total_loss / (float)n; } /* ============================================================ * File I/O * ============================================================ */ static void write_tensor(FILE *f, const Tensor *t) { fwrite(&t->rows, sizeof(int), 1, f); fwrite(&t->cols, sizeof(int), 1, f); fwrite(t->data, sizeof(float), (size_t)t->size, f); } static void read_tensor(FILE *f, Tensor *t) { int rows, cols; if (fread(&rows, sizeof(int), 1, f) != 1) return; if (fread(&cols, sizeof(int), 1, f) != 1) return; *t = tensor_alloc(rows, cols); if (fread(t->data, sizeof(float), (size_t)t->size, f) != (size_t)t->size) { fprintf(stderr, "Warning: incomplete tensor read\n"); } } static void write_layer_params(FILE *f, const LayerParams *lp) { write_tensor(f, &lp->ln1_weight); write_tensor(f, &lp->ln1_bias); write_tensor(f, &lp->ln2_weight); write_tensor(f, &lp->ln2_bias); write_tensor(f, &lp->time_shift_w1); write_tensor(f, &lp->time_shift_w2); write_tensor(f, &lp->time_shift_w4); write_tensor(f, &lp->time_mix_r); write_tensor(f, &lp->time_mix_k); write_tensor(f, &lp->time_mix_v); write_tensor(f, &lp->decay_lora_a); write_tensor(f, &lp->decay_lora_b); write_tensor(f, &lp->decay_base); write_tensor(f, &lp->time_first); write_tensor(f, &lp->Wr); write_tensor(f, &lp->Wk); write_tensor(f, &lp->Wv); write_tensor(f, &lp->Wo); write_tensor(f, &lp->channel_mix); write_tensor(f, &lp->ffn_gate); write_tensor(f, &lp->ffn_up); write_tensor(f, &lp->ffn_down); write_tensor(f, &lp->mem_gate_write); write_tensor(f, &lp->mem_gate_read); } static void read_layer_params(FILE *f, LayerParams *lp) { read_tensor(f, &lp->ln1_weight); read_tensor(f, &lp->ln1_bias); read_tensor(f, &lp->ln2_weight); read_tensor(f, &lp->ln2_bias); read_tensor(f, &lp->time_shift_w1); read_tensor(f, &lp->time_shift_w2); read_tensor(f, &lp->time_shift_w4); read_tensor(f, &lp->time_mix_r); read_tensor(f, &lp->time_mix_k); read_tensor(f, &lp->time_mix_v); read_tensor(f, &lp->decay_lora_a); 
read_tensor(f, &lp->decay_lora_b); read_tensor(f, &lp->decay_base); read_tensor(f, &lp->time_first); read_tensor(f, &lp->Wr); read_tensor(f, &lp->Wk); read_tensor(f, &lp->Wv); read_tensor(f, &lp->Wo); read_tensor(f, &lp->channel_mix); read_tensor(f, &lp->ffn_gate); read_tensor(f, &lp->ffn_up); read_tensor(f, &lp->ffn_down); read_tensor(f, &lp->mem_gate_write); read_tensor(f, &lp->mem_gate_read); } /* ============================================================ * File I/O (Updated for Hybrid Tokenizer) * ============================================================ */ static int save_model(const char *path, const ModelParams *mp, const lrnnConfig *cfg, const Tokenizer *tok) { FILE *f = fopen(path, "wb"); if (!f) { fprintf(stderr, "Error: cannot open %s for writing\n", path); return -1; } /* Magic and version (updated magic for new format) */ const char magic[] = "lrnnC02"; /* Version 2 for tokenizer support */ fwrite(magic, 1, 8, f); fwrite(cfg, sizeof(lrnnConfig), 1, f); /* Tokenizer type */ int tok_type = (int)tok->type; fwrite(&tok_type, sizeof(int), 1, f); /* Save vocabulary based on type */ if (tok->type == TOKENIZER_CHAR) { fwrite(&tok->char_vocab.size, sizeof(int), 1, f); fwrite(tok->char_vocab.chars, sizeof(char), (size_t)tok->char_vocab.size, f); fwrite(tok->char_vocab.char_to_idx, sizeof(int), 256, f); } else { /* Word vocabulary */ fwrite(&tok->word_vocab.size, sizeof(int), 1, f); fwrite(&tok->word_vocab.unk_idx, sizeof(int), 1, f); fwrite(&tok->word_vocab.pad_idx, sizeof(int), 1, f); fwrite(&tok->word_vocab.space_idx, sizeof(int), 1, f); fwrite(&tok->word_vocab.newline_idx, sizeof(int), 1, f); /* Save each word with length prefix */ for (int i = 0; i < tok->word_vocab.size; i++) { int len = (int)strlen(tok->word_vocab.words[i]); fwrite(&len, sizeof(int), 1, f); fwrite(tok->word_vocab.words[i], sizeof(char), (size_t)len, f); } } /* Model params */ write_tensor(f, &mp->emb); write_tensor(f, &mp->ln0_weight); write_tensor(f, &mp->ln0_bias); 
fwrite(&mp->n_layers, sizeof(int), 1, f); for (int i = 0; i < mp->n_layers; i++) { write_layer_params(f, &mp->layers[i]); } write_tensor(f, &mp->ln_out_weight); write_tensor(f, &mp->ln_out_bias); write_tensor(f, &mp->head); fclose(f); return 0; } static int load_model(const char *path, ModelParams *mp, lrnnConfig *cfg, Tokenizer *tok) { FILE *f = fopen(path, "rb"); if (!f) { fprintf(stderr, "Error: cannot open %s for reading\n", path); return -1; } char magic[8]; if (fread(magic, 1, 8, f) != 8) { fclose(f); return -1; } /* Check version */ bool is_v1 = (strncmp(magic, "lrnnC01", 7) == 0); bool is_v2 = (strncmp(magic, "lrnnC02", 7) == 0); if (!is_v1 && !is_v2) { fprintf(stderr, "Error: invalid model file format\n"); fclose(f); return -1; } if (fread(cfg, sizeof(lrnnConfig), 1, f) != 1) { fclose(f); return -1; } memset(tok, 0, sizeof(Tokenizer)); if (is_v1) { /* Old format: character vocabulary only */ tok->type = TOKENIZER_CHAR; if (fread(&tok->char_vocab.size, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if (fread(tok->char_vocab.chars, sizeof(char), (size_t)tok->char_vocab.size, f) != (size_t)tok->char_vocab.size) { fclose(f); return -1; } if (fread(tok->char_vocab.char_to_idx, sizeof(int), 256, f) != 256) { fclose(f); return -1; } } else { int tok_type; if (fread(&tok_type, sizeof(int), 1, f) != 1) { fclose(f); return -1; } tok->type = (TokenizerType)tok_type; if (tok->type == TOKENIZER_CHAR) { if (fread(&tok->char_vocab.size, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if (fread(tok->char_vocab.chars, sizeof(char), (size_t)tok->char_vocab.size, f) != (size_t)tok->char_vocab.size) { fclose(f); return -1; } if (fread(tok->char_vocab.char_to_idx, sizeof(int), 256, f) != 256) { fclose(f); return -1; } } else { /* Word vocabulary */ init_word_vocabulary(&tok->word_vocab); int size; if (fread(&size, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if (fread(&tok->word_vocab.unk_idx, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if 
(fread(&tok->word_vocab.pad_idx, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if (fread(&tok->word_vocab.space_idx, sizeof(int), 1, f) != 1) { fclose(f); return -1; } if (fread(&tok->word_vocab.newline_idx, sizeof(int), 1, f) != 1) { fclose(f); return -1; } /* Read each word */ for (int i = 0; i < size; i++) { int len; if (fread(&len, sizeof(int), 1, f) != 1) { fclose(f); return -1; } char word[MAX_WORD_LEN]; if (len >= MAX_WORD_LEN) len = MAX_WORD_LEN - 1; if (fread(word, sizeof(char), (size_t)len, f) != (size_t)len) { fclose(f); return -1; } word[len] = '\0'; word_vocab_add(&tok->word_vocab, word); } } } /* Model params */ read_tensor(f, &mp->emb); read_tensor(f, &mp->ln0_weight); read_tensor(f, &mp->ln0_bias); if (fread(&mp->n_layers, sizeof(int), 1, f) != 1) { fclose(f); return -1; } mp->layers = (LayerParams *)calloc((size_t)mp->n_layers, sizeof(LayerParams)); for (int i = 0; i < mp->n_layers; i++) { read_layer_params(f, &mp->layers[i]); } /* Recompute ALiBi slopes (deterministic from config) */ for (int i = 0; i < mp->n_layers; i++) { mp->layers[i].alibi_slopes = tensor_alloc_1d(cfg->n_head); compute_alibi_slopes(mp->layers[i].alibi_slopes.data, cfg->n_head); } read_tensor(f, &mp->ln_out_weight); read_tensor(f, &mp->ln_out_bias); read_tensor(f, &mp->head); fclose(f); return 0; } /* ============================================================ * Vocabulary Building * ============================================================ */ static void build_vocabulary(Vocabulary *vocab, const char *text, size_t len) { bool seen[256] = {false}; vocab->size = 0; for (size_t i = 0; i < len; i++) { unsigned char c = (unsigned char)text[i]; if (!seen[c]) { seen[c] = true; vocab->chars[vocab->size] = (char)c; vocab->char_to_idx[c] = vocab->size; vocab->size++; } } /* Sort for consistency */ for (int i = 0; i < vocab->size - 1; i++) { for (int j = i + 1; j < vocab->size; j++) { if ((unsigned char)vocab->chars[i] > (unsigned char)vocab->chars[j]) { char tmp = 
vocab->chars[i]; vocab->chars[i] = vocab->chars[j]; vocab->chars[j] = tmp; } } } /* Rebuild index */ for (int i = 0; i < 256; i++) { vocab->char_to_idx[i] = 0; } for (int i = 0; i < vocab->size; i++) { vocab->char_to_idx[(unsigned char)vocab->chars[i]] = i; } } /* ============================================================ * Word-Level Vocabulary Implementation * ============================================================ */ static unsigned int word_hash(const char *word) { unsigned int hash = 5381; while (*word) { hash = ((hash << 5) + hash) ^ (unsigned char)*word++; } return hash; } static void init_word_vocabulary(WordVocabulary *wv) { wv->capacity = MAX_WORDS; wv->words = (char **)calloc((size_t)wv->capacity, sizeof(char *)); wv->hash_table = (int *)malloc(WORD_HASH_SIZE * sizeof(int)); wv->hash_keys = (int *)malloc(WORD_HASH_SIZE * sizeof(int)); for (int i = 0; i < WORD_HASH_SIZE; i++) { wv->hash_table[i] = -1; wv->hash_keys[i] = -1; } wv->size = 0; wv->unk_idx = -1; wv->pad_idx = -1; wv->space_idx = -1; wv->newline_idx = -1; } static void free_word_vocabulary(WordVocabulary *wv) { if (wv->words) { for (int i = 0; i < wv->size; i++) { free(wv->words[i]); } free(wv->words); wv->words = NULL; } if (wv->hash_table) { free(wv->hash_table); wv->hash_table = NULL; } if (wv->hash_keys) { free(wv->hash_keys); wv->hash_keys = NULL; } wv->size = 0; } static int word_vocab_find(const WordVocabulary *wv, const char *word) { unsigned int hash = word_hash(word); unsigned int idx = hash % WORD_HASH_SIZE; for (int probe = 0; probe < 1000; probe++) { unsigned int slot = (idx + probe) % WORD_HASH_SIZE; if (wv->hash_table[slot] < 0) { return -1; /* Not found */ } if (wv->hash_keys[slot] == (int)hash) { int word_idx = wv->hash_table[slot]; if (strcmp(wv->words[word_idx], word) == 0) { return word_idx; } } } return -1; } static int word_vocab_add(WordVocabulary *wv, const char *word) { /* Check if already exists */ int existing = word_vocab_find(wv, word); if (existing >= 0) 
return existing; /* Check capacity */ if (wv->size >= wv->capacity - 1) { fprintf(stderr, "Warning: word vocabulary full\n"); return wv->unk_idx; } /* Add word */ int word_idx = wv->size; wv->words[word_idx] = strdup(word); wv->size++; /* Add to hash table */ unsigned int hash = word_hash(word); unsigned int idx = hash % WORD_HASH_SIZE; for (int probe = 0; probe < 1000; probe++) { unsigned int slot = (idx + probe) % WORD_HASH_SIZE; if (wv->hash_table[slot] < 0) { wv->hash_table[slot] = word_idx; wv->hash_keys[slot] = (int)hash; break; } } return word_idx; } static inline int is_word_boundary(char c) { return c == ' ' || c == '\n' || c == '\t' || c == '\r' || c == '.' || c == ',' || c == '!' || c == '?' || c == ':' || c == ';' || c == '"' || c == '\'' || c == '(' || c == ')' || c == '[' || c == ']' || c == '{' || c == '}' || c == '-' || c == '/' || c == '\\' || c == '@' || c == '#' || c == '$' || c == '%' || c == '&' || c == '*' || c == '+' || c == '=' || c == '<' || c == '>' || c == '|' || c == '~' || c == '`' || c == '^'; } static void build_word_vocabulary(WordVocabulary *wv, const char *text, size_t len) { init_word_vocabulary(wv); /* Add special tokens first */ wv->unk_idx = word_vocab_add(wv, ""); wv->pad_idx = word_vocab_add(wv, ""); wv->space_idx = word_vocab_add(wv, " "); wv->newline_idx = word_vocab_add(wv, "\n"); /* Add common punctuation as separate tokens */ word_vocab_add(wv, "."); word_vocab_add(wv, ","); word_vocab_add(wv, "!"); word_vocab_add(wv, "?"); word_vocab_add(wv, ":"); word_vocab_add(wv, ";"); word_vocab_add(wv, "\""); word_vocab_add(wv, "'"); word_vocab_add(wv, "("); word_vocab_add(wv, ")"); word_vocab_add(wv, "-"); word_vocab_add(wv, "\t"); /* Parse text and add words */ char word[MAX_WORD_LEN]; int word_len = 0; for (size_t i = 0; i < len; i++) { char c = text[i]; if (is_word_boundary(c)) { /* End current word */ if (word_len > 0) { word[word_len] = '\0'; word_vocab_add(wv, word); word_len = 0; } /* Add boundary char as token (except 
space/tab which we handle specially) */ if (c != ' ' && c != '\t' && c != '\r') { char punct[2] = {c, '\0'}; word_vocab_add(wv, punct); } } else { /* Accumulate word */ if (word_len < MAX_WORD_LEN - 1) { word[word_len++] = c; } } } /* Handle last word */ if (word_len > 0) { word[word_len] = '\0'; word_vocab_add(wv, word); } } static int *tokenize_words(const char *text, size_t len, const WordVocabulary *wv, int *out_len) { /* Estimate max tokens */ int max_tokens = (int)(len / 2) + 100; int *tokens = (int *)malloc((size_t)max_tokens * sizeof(int)); int n_tokens = 0; char word[MAX_WORD_LEN]; int word_len = 0; for (size_t i = 0; i < len; i++) { char c = text[i]; if (is_word_boundary(c)) { /* End current word */ if (word_len > 0) { word[word_len] = '\0'; int idx = word_vocab_find(wv, word); tokens[n_tokens++] = (idx >= 0) ? idx : wv->unk_idx; word_len = 0; } /* Add boundary token */ if (c == ' ') { tokens[n_tokens++] = wv->space_idx; } else if (c == '\n') { tokens[n_tokens++] = wv->newline_idx; } else if (c == '\t') { tokens[n_tokens++] = wv->space_idx; /* Treat tab as space */ } else if (c != '\r') { char punct[2] = {c, '\0'}; int idx = word_vocab_find(wv, punct); if (idx >= 0) { tokens[n_tokens++] = idx; } } } else { if (word_len < MAX_WORD_LEN - 1) { word[word_len++] = c; } } /* Grow buffer if needed */ if (n_tokens >= max_tokens - 10) { max_tokens *= 2; tokens = (int *)realloc(tokens, (size_t)max_tokens * sizeof(int)); } } /* Handle last word */ if (word_len > 0) { word[word_len] = '\0'; int idx = word_vocab_find(wv, word); tokens[n_tokens++] = (idx >= 0) ? 
idx : wv->unk_idx; } *out_len = n_tokens; return tokens; } static const char *decode_word_token(int token, const WordVocabulary *wv) { if (token >= 0 && token < wv->size && wv->words[token]) { return wv->words[token]; } return ""; } /* ============================================================ * Unified Tokenizer Interface * ============================================================ */ static void init_tokenizer(Tokenizer *tok, TokenizerType type) { memset(tok, 0, sizeof(Tokenizer)); tok->type = type; } static void free_tokenizer(Tokenizer *tok) { if (tok->type == TOKENIZER_WORD) { free_word_vocabulary(&tok->word_vocab); } /* char_vocab doesn't need explicit free (static arrays) */ } static void build_tokenizer(Tokenizer *tok, const char *text, size_t len, TokenizerType requested_type) { /* Auto-select based on corpus size */ if (requested_type == TOKENIZER_AUTO) { if (len < 20000) { tok->type = TOKENIZER_CHAR; printf(" Auto-selected: character tokenizer (corpus < 20KB)\n"); } else { tok->type = TOKENIZER_WORD; printf(" Auto-selected: word tokenizer (corpus >= 20KB)\n"); } } else { tok->type = requested_type; } if (tok->type == TOKENIZER_CHAR) { build_vocabulary(&tok->char_vocab, text, len); } else { build_word_vocabulary(&tok->word_vocab, text, len); } } static int tokenizer_vocab_size(const Tokenizer *tok) { if (tok->type == TOKENIZER_CHAR) { return tok->char_vocab.size; } else { return tok->word_vocab.size; } } static int *tokenizer_encode(const Tokenizer *tok, const char *text, size_t len, int *out_len) { if (tok->type == TOKENIZER_CHAR) { return tokenize(text, len, &tok->char_vocab, out_len); } else { return tokenize_words(text, len, &tok->word_vocab, out_len); } } static void tokenizer_decode_token(const Tokenizer *tok, int token, char *out, int out_size) { if (tok->type == TOKENIZER_CHAR) { if (token >= 0 && token < tok->char_vocab.size) { out[0] = tok->char_vocab.chars[token]; out[1] = '\0'; } else { out[0] = '?'; out[1] = '\0'; } } else { const char 
*word = decode_word_token(token, &tok->word_vocab); strncpy(out, word, out_size - 1); out[out_size - 1] = '\0'; } } /* Auto-configure model based on corpus and tokenizer */ static lrnnConfig config_for_corpus(long corpus_bytes, TokenizerType tok_type, int vocab_size) { lrnnConfig cfg = default_config(); cfg.vocab_size = vocab_size; /* Word tokenizer is more efficient, so we can use smaller models */ float efficiency = (tok_type == TOKENIZER_WORD) ? 5.0f : 1.0f; long effective_size = (long)(corpus_bytes / efficiency); if (effective_size < 5000) { cfg.n_layer = 2; cfg.n_embd = 64; cfg.ctx_len = 64; cfg.decay_lora_rank = 4; cfg.ffn_multiplier = 1.5f; cfg.n_mem_slots = 2; } else if (effective_size < 50000) { cfg.n_layer = 4; cfg.n_embd = 128; cfg.ctx_len = 128; cfg.decay_lora_rank = 8; cfg.ffn_multiplier = 2.0f; cfg.n_mem_slots = 4; } else if (effective_size < 500000) { cfg.n_layer = 6; cfg.n_embd = 256; cfg.ctx_len = 256; cfg.decay_lora_rank = 16; cfg.ffn_multiplier = 2.5f; cfg.n_mem_slots = 4; } else { cfg.n_layer = 8; cfg.n_embd = 384; cfg.ctx_len = 512; cfg.decay_lora_rank = 32; cfg.ffn_multiplier = 3.0f; cfg.n_mem_slots = 8; } cfg.n_head = cfg.n_embd / 32; if (cfg.n_head < 2) cfg.n_head = 2; return cfg; } static int *tokenize(const char *text, size_t len, const Vocabulary *vocab, int *out_len) { int *tokens = (int *)malloc(len * sizeof(int)); for (size_t i = 0; i < len; i++) { tokens[i] = vocab->char_to_idx[(unsigned char)text[i]]; } *out_len = (int)len; return tokens; } /* ============================================================ * TRAINING SECTION - Full Backpropagation Implementation * ============================================================ * * This implements analytical gradients for all model parameters: * - Embedding layer * - Layer normalization (all instances) * - Multi-scale token shift weights * - Token mixing parameters (r, k, v) * - Data-dependent decay (LoRA + base) * - Projection matrices (Wr, Wk, Wv, Wo) * - SwiGLU FFN (gate, up, down) * - 
Output head */ /* ============================================================ * Gradient Structures * ============================================================ */ typedef struct { /* Layer norms */ Tensor ln1_weight, ln1_bias; Tensor ln2_weight, ln2_bias; /* Multi-scale token shift */ Tensor time_shift_w1, time_shift_w2, time_shift_w4; /* Token mixing ratios */ Tensor time_mix_r, time_mix_k, time_mix_v; /* Data-dependent decay (low-rank) */ Tensor decay_lora_a, decay_lora_b; Tensor decay_base; Tensor time_first; /* Projections */ Tensor Wr, Wk, Wv, Wo; /* Channel mix */ Tensor channel_mix; Tensor ffn_gate, ffn_up, ffn_down; /* In LayerGrads, add: */ Tensor mem_gate_write; /* (n_embd, n_mem_slots) */ Tensor mem_gate_read; /* (n_embd, n_mem_slots) */ } LayerGrads; typedef struct { Tensor emb; Tensor ln0_weight, ln0_bias; LayerGrads *layers; Tensor ln_out_weight, ln_out_bias; Tensor head; int n_layers; } ModelGrads; /* ============================================================ * Forward Pass Cache (for backpropagation) * ============================================================ */ typedef struct { Tensor x_in; /* TimeMix forward cache */ Tensor x_ln1; /* After first layer norm */ Tensor x_shifted; /* Multi-scale shifted */ Tensor shift_w1_sig; /* sigmoid(time_shift_w1) */ Tensor shift_w2_sig; /* sigmoid(time_shift_w2) */ Tensor shift_w4_sig; /* sigmoid(time_shift_w4) */ Tensor shift_w_sum; /* w1 + w2 + w4 + eps */ Tensor xr, xk, xv; /* After mixing */ Tensor mix_r_sig; /* sigmoid(time_mix_r) */ Tensor mix_k_sig; /* sigmoid(time_mix_k) */ Tensor mix_v_sig; /* sigmoid(time_mix_v) */ Tensor r_pre; /* Before sigmoid */ Tensor k_pre; /* Before exp */ Tensor v; /* v values */ Tensor r; /* After sigmoid (receptance) */ Tensor k_exp; /* After exp */ Tensor decay_tmp; /* LoRA intermediate (seq, rank) */ Tensor decay_delta; /* LoRA output */ Tensor decay_pre; /* Before sigmoid */ Tensor decay; /* After sigmoid */ Tensor time_first_exp; /* exp(time_first) */ Tensor 
*num_states; /* (seq+1) tensors, each (n_mem_slots, n_embd) */ Tensor *den_states; /* (seq+1) tensors, each (n_mem_slots, n_embd) */ Tensor *write_gates; /* (seq) tensors, each (n_mem_slots,) */ Tensor *read_gates; /* (seq) tensors, each (n_mem_slots,) */ Tensor wkv; /* WKV output */ Tensor wkv_r; /* wkv * r */ Tensor tm_out; /* After Wo projection */ Tensor x_after_tm; /* x + tm_out (residual) */ /* ChannelMix forward cache */ Tensor x_ln2; /* After second layer norm */ Tensor xm; /* After channel mixing */ Tensor cm_mix_sig; /* sigmoid(channel_mix) */ Tensor gate_pre; /* Before silu */ Tensor up_val; /* up values */ Tensor gate_silu; /* After silu */ Tensor hidden; /* gate * up */ Tensor cm_out; /* After down projection */ } LayerCache; typedef struct { int seq_len; int n_layers; Tensor emb_out; /* Token embeddings (seq, n_embd) */ Tensor x_ln0; /* After initial layer norm */ LayerCache *layers; /* Per-layer cache */ Tensor x_final; /* Final hidden states before output ln */ Tensor x_ln_out; /* After final layer norm */ Tensor logits; /* Final logits (seq, vocab) */ } ForwardCache; /* ============================================================ * Gradient Allocation and Deallocation * ============================================================ */ static void init_layer_grads(LayerGrads *lg, const lrnnConfig *cfg, const LayerParams *lp) { int n_embd = cfg->n_embd; int ffn_h = ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; int n_slots = cfg->n_mem_slots; lg->ln1_weight = tensor_alloc(lp->ln1_weight.rows, lp->ln1_weight.cols); lg->ln1_bias = tensor_alloc(lp->ln1_bias.rows, lp->ln1_bias.cols); lg->ln2_weight = tensor_alloc(lp->ln2_weight.rows, lp->ln2_weight.cols); lg->ln2_bias = tensor_alloc(lp->ln2_bias.rows, lp->ln2_bias.cols); lg->time_shift_w1 = tensor_alloc_1d(n_embd); lg->time_shift_w2 = tensor_alloc_1d(n_embd); lg->time_shift_w4 = tensor_alloc_1d(n_embd); lg->time_mix_r = tensor_alloc_1d(n_embd); lg->time_mix_k = tensor_alloc_1d(n_embd); 
lg->time_mix_v = tensor_alloc_1d(n_embd); lg->decay_lora_a = tensor_alloc(n_embd, lora_rank); lg->decay_lora_b = tensor_alloc(lora_rank, n_embd); lg->decay_base = tensor_alloc_1d(n_embd); lg->time_first = tensor_alloc_1d(n_embd); lg->Wr = tensor_alloc(n_embd, n_embd); lg->Wk = tensor_alloc(n_embd, n_embd); lg->Wv = tensor_alloc(n_embd, n_embd); lg->Wo = tensor_alloc(n_embd, n_embd); lg->channel_mix = tensor_alloc_1d(n_embd); lg->ffn_gate = tensor_alloc(n_embd, ffn_h); lg->ffn_up = tensor_alloc(n_embd, ffn_h); lg->ffn_down = tensor_alloc(ffn_h, n_embd); lg->mem_gate_write = tensor_alloc(n_embd, n_slots); lg->mem_gate_read = tensor_alloc(n_embd, n_slots); } static void zero_layer_grads(LayerGrads *lg) { tensor_zero(&lg->ln1_weight); tensor_zero(&lg->ln1_bias); tensor_zero(&lg->ln2_weight); tensor_zero(&lg->ln2_bias); tensor_zero(&lg->time_shift_w1); tensor_zero(&lg->time_shift_w2); tensor_zero(&lg->time_shift_w4); tensor_zero(&lg->time_mix_r); tensor_zero(&lg->time_mix_k); tensor_zero(&lg->time_mix_v); tensor_zero(&lg->decay_lora_a); tensor_zero(&lg->decay_lora_b); tensor_zero(&lg->decay_base); tensor_zero(&lg->time_first); tensor_zero(&lg->Wr); tensor_zero(&lg->Wk); tensor_zero(&lg->Wv); tensor_zero(&lg->Wo); tensor_zero(&lg->channel_mix); tensor_zero(&lg->ffn_gate); tensor_zero(&lg->ffn_up); tensor_zero(&lg->ffn_down); // ADD: tensor_zero(&lg->mem_gate_write); tensor_zero(&lg->mem_gate_read); } static void free_layer_grads(LayerGrads *lg) { tensor_free(&lg->ln1_weight); tensor_free(&lg->ln1_bias); tensor_free(&lg->ln2_weight); tensor_free(&lg->ln2_bias); tensor_free(&lg->time_shift_w1); tensor_free(&lg->time_shift_w2); tensor_free(&lg->time_shift_w4); tensor_free(&lg->time_mix_r); tensor_free(&lg->time_mix_k); tensor_free(&lg->time_mix_v); tensor_free(&lg->decay_lora_a); tensor_free(&lg->decay_lora_b); tensor_free(&lg->decay_base); tensor_free(&lg->time_first); tensor_free(&lg->Wr); tensor_free(&lg->Wk); tensor_free(&lg->Wv); tensor_free(&lg->Wo); 
tensor_free(&lg->channel_mix); tensor_free(&lg->ffn_gate); tensor_free(&lg->ffn_up); tensor_free(&lg->ffn_down); tensor_free(&lg->mem_gate_write); tensor_free(&lg->mem_gate_read); } static void init_model_grads(ModelGrads *mg, const ModelParams *mp, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int vocab_size = cfg->vocab_size; mg->n_layers = cfg->n_layer; mg->emb = tensor_alloc(vocab_size, n_embd); mg->ln0_weight = tensor_alloc_1d(n_embd); mg->ln0_bias = tensor_alloc_1d(n_embd); mg->layers = (LayerGrads *)calloc((size_t)cfg->n_layer, sizeof(LayerGrads)); for (int i = 0; i < cfg->n_layer; i++) { init_layer_grads(&mg->layers[i], cfg, &mp->layers[i]); } mg->ln_out_weight = tensor_alloc_1d(n_embd); mg->ln_out_bias = tensor_alloc_1d(n_embd); mg->head = tensor_alloc(n_embd, vocab_size); } static void zero_model_grads(ModelGrads *mg) { tensor_zero(&mg->emb); tensor_zero(&mg->ln0_weight); tensor_zero(&mg->ln0_bias); for (int i = 0; i < mg->n_layers; i++) { zero_layer_grads(&mg->layers[i]); } tensor_zero(&mg->ln_out_weight); tensor_zero(&mg->ln_out_bias); tensor_zero(&mg->head); } static void free_model_grads(ModelGrads *mg) { tensor_free(&mg->emb); tensor_free(&mg->ln0_weight); tensor_free(&mg->ln0_bias); for (int i = 0; i < mg->n_layers; i++) { free_layer_grads(&mg->layers[i]); } free(mg->layers); mg->layers = NULL; tensor_free(&mg->ln_out_weight); tensor_free(&mg->ln_out_bias); tensor_free(&mg->head); } /* ============================================================ * Forward Cache Allocation and Deallocation * ============================================================ */ static void init_layer_cache(LayerCache *lc, int seq_len, const lrnnConfig *cfg) { if (seq_len <= 0) { fprintf(stderr, "FATAL: init_layer_cache called with seq_len=%d\n", seq_len); exit(1); } int n_embd = cfg->n_embd; int ffn_h = ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; int n_slots = cfg->n_mem_slots; /* --- All regular cache tensors --- */ lc->x_in = tensor_alloc(seq_len, n_embd); 
lc->x_ln1 = tensor_alloc(seq_len, n_embd); lc->x_shifted = tensor_alloc(seq_len, n_embd); lc->shift_w1_sig = tensor_alloc_1d(n_embd); lc->shift_w2_sig = tensor_alloc_1d(n_embd); lc->shift_w4_sig = tensor_alloc_1d(n_embd); lc->shift_w_sum = tensor_alloc_1d(n_embd); lc->xr = tensor_alloc(seq_len, n_embd); lc->xk = tensor_alloc(seq_len, n_embd); lc->xv = tensor_alloc(seq_len, n_embd); lc->mix_r_sig = tensor_alloc_1d(n_embd); lc->mix_k_sig = tensor_alloc_1d(n_embd); lc->mix_v_sig = tensor_alloc_1d(n_embd); lc->r_pre = tensor_alloc(seq_len, n_embd); lc->k_pre = tensor_alloc(seq_len, n_embd); lc->v = tensor_alloc(seq_len, n_embd); lc->r = tensor_alloc(seq_len, n_embd); lc->k_exp = tensor_alloc(seq_len, n_embd); lc->decay_tmp = tensor_alloc(seq_len, lora_rank); lc->decay_delta = tensor_alloc(seq_len, n_embd); lc->decay_pre = tensor_alloc(seq_len, n_embd); lc->decay = tensor_alloc(seq_len, n_embd); lc->time_first_exp = tensor_alloc_1d(n_embd); lc->wkv = tensor_alloc(seq_len, n_embd); lc->wkv_r = tensor_alloc(seq_len, n_embd); lc->tm_out = tensor_alloc(seq_len, n_embd); lc->x_after_tm = tensor_alloc(seq_len, n_embd); lc->x_ln2 = tensor_alloc(seq_len, n_embd); lc->xm = tensor_alloc(seq_len, n_embd); lc->cm_mix_sig = tensor_alloc_1d(n_embd); lc->gate_pre = tensor_alloc(seq_len, ffn_h); lc->up_val = tensor_alloc(seq_len, ffn_h); lc->gate_silu = tensor_alloc(seq_len, ffn_h); lc->hidden = tensor_alloc(seq_len, ffn_h); lc->cm_out = tensor_alloc(seq_len, n_embd); /* --- Multi-slot WKV states --- */ lc->num_states = (Tensor *)calloc((size_t)(seq_len + 1), sizeof(Tensor)); lc->den_states = (Tensor *)calloc((size_t)(seq_len + 1), sizeof(Tensor)); if (!lc->num_states || !lc->den_states) { fprintf(stderr, "Failed to allocate WKV state arrays\n"); exit(1); } for (int t = 0; t <= seq_len; t++) { lc->num_states[t] = tensor_alloc(n_slots, n_embd); lc->den_states[t] = tensor_alloc(n_slots, n_embd); } /* --- Memory gates --- */ lc->write_gates = (Tensor *)calloc((size_t)seq_len, 
sizeof(Tensor)); lc->read_gates = (Tensor *)calloc((size_t)seq_len, sizeof(Tensor)); if (!lc->write_gates || !lc->read_gates) { fprintf(stderr, "Failed to allocate gate arrays\n"); exit(1); } for (int t = 0; t < seq_len; t++) { lc->write_gates[t] = tensor_alloc_1d(n_slots); lc->read_gates[t] = tensor_alloc_1d(n_slots); } } static void free_layer_cache(LayerCache *lc, int seq_len) { tensor_free(&lc->x_in); tensor_free(&lc->x_ln1); tensor_free(&lc->x_shifted); tensor_free(&lc->shift_w1_sig); tensor_free(&lc->shift_w2_sig); tensor_free(&lc->shift_w4_sig); tensor_free(&lc->shift_w_sum); tensor_free(&lc->xr); tensor_free(&lc->xk); tensor_free(&lc->xv); tensor_free(&lc->mix_r_sig); tensor_free(&lc->mix_k_sig); tensor_free(&lc->mix_v_sig); tensor_free(&lc->r_pre); tensor_free(&lc->k_pre); tensor_free(&lc->v); tensor_free(&lc->r); tensor_free(&lc->k_exp); tensor_free(&lc->decay_tmp); tensor_free(&lc->decay_delta); tensor_free(&lc->decay_pre); tensor_free(&lc->decay); tensor_free(&lc->time_first_exp); for (int t = 0; t <= seq_len; t++) { tensor_free(&lc->num_states[t]); tensor_free(&lc->den_states[t]); } free(lc->num_states); free(lc->den_states); tensor_free(&lc->wkv); tensor_free(&lc->wkv_r); tensor_free(&lc->tm_out); tensor_free(&lc->x_after_tm); tensor_free(&lc->x_ln2); tensor_free(&lc->xm); tensor_free(&lc->cm_mix_sig); tensor_free(&lc->gate_pre); tensor_free(&lc->up_val); tensor_free(&lc->gate_silu); tensor_free(&lc->hidden); tensor_free(&lc->cm_out); // free(lc->num_states); // free(lc->den_states); for (int t = 0; t < seq_len; t++) { tensor_free(&lc->write_gates[t]); tensor_free(&lc->read_gates[t]); } free(lc->write_gates); free(lc->read_gates); } static void init_forward_cache(ForwardCache *fc, int seq_len, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int vocab_size = cfg->vocab_size; fc->seq_len = seq_len; fc->n_layers = cfg->n_layer; fc->emb_out = tensor_alloc(seq_len, n_embd); fc->x_ln0 = tensor_alloc(seq_len, n_embd); fc->layers = (LayerCache 
*)calloc((size_t)cfg->n_layer, sizeof(LayerCache)); for (int i = 0; i < cfg->n_layer; i++) { init_layer_cache(&fc->layers[i], seq_len, cfg); } fc->x_final = tensor_alloc(seq_len, n_embd); fc->x_ln_out = tensor_alloc(seq_len, n_embd); fc->logits = tensor_alloc(seq_len, vocab_size); } static void free_forward_cache(ForwardCache *fc) { tensor_free(&fc->emb_out); tensor_free(&fc->x_ln0); for (int i = 0; i < fc->n_layers; i++) { free_layer_cache(&fc->layers[i], fc->seq_len); } free(fc->layers); tensor_free(&fc->x_final); tensor_free(&fc->x_ln_out); tensor_free(&fc->logits); } /* ============================================================ * Backward Primitive Operations * ============================================================ */ /* Sigmoid backward: d_input = d_output * sigmoid(x) * (1 - sigmoid(x)) * Given y = sigmoid(x), d_input = d_output * y * (1 - y) */ static void sigmoid_backward(float *d_input, const float *d_output, const float *y, int n) { for (int i = 0; i < n; i++) { d_input[i] = d_output[i] * y[i] * (1.0f - y[i]); } } /* SiLU backward: y = x * sigmoid(x) * dy/dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) * = sigmoid(x) * (1 + x * (1 - sigmoid(x))) */ static void silu_backward(float *d_input, const float *d_output, const float *x, int n) { for (int i = 0; i < n; i++) { float s = sigmoid_f(x[i]); float grad = s * (1.0f + x[i] * (1.0f - s)); d_input[i] = d_output[i] * grad; } } /* Exp backward: d_input = d_output * exp(x) = d_output * y * But we clamp, so we need the clamped version */ static void exp_backward_clamped(float *d_input, const float *d_output, const float *x, int n) { for (int i = 0; i < n; i++) { float clamped = clamp_f(x[i], -10.0f, 10.0f); float y = expf(clamped); /* Gradient is zero outside the clamp range */ if (x[i] < -10.0f || x[i] > 10.0f) { d_input[i] = 0.0f; } else { d_input[i] = d_output[i] * y; } } } /* Matrix multiply backward: Y = X @ W * dL/dX = dL/dY @ W^T * dL/dW = X^T @ dL/dY */ static void matmul_backward_x(Tensor 
*d_X, const Tensor *d_Y, const Tensor *W) { /* d_X = d_Y @ W^T */ /* d_Y: (seq, out_dim), W: (in_dim, out_dim), d_X: (seq, in_dim) */ int seq_len = d_Y->rows; int out_dim = d_Y->cols; int in_dim = W->rows; for (int s = 0; s < seq_len; s++) { for (int i = 0; i < in_dim; i++) { float sum = 0.0f; for (int j = 0; j < out_dim; j++) { sum += d_Y->data[s * out_dim + j] * W->data[i * out_dim + j]; } d_X->data[s * in_dim + i] = sum; } } } static void matmul_backward_w(Tensor *d_W, const Tensor *d_Y, const Tensor *X) { /* d_W = X^T @ d_Y */ /* X: (seq, in_dim), d_Y: (seq, out_dim), d_W: (in_dim, out_dim) */ int seq_len = X->rows; int in_dim = X->cols; int out_dim = d_Y->cols; for (int i = 0; i < in_dim; i++) { for (int j = 0; j < out_dim; j++) { float sum = 0.0f; for (int s = 0; s < seq_len; s++) { sum += X->data[s * in_dim + i] * d_Y->data[s * out_dim + j]; } d_W->data[i * out_dim + j] += sum; /* accumulate */ } } } /* Layer normalization backward * y = weight * (x - mean) / sqrt(var + eps) + bias * This is a bit complex due to the mean and variance dependencies */ static void layer_norm_backward_single(float *d_x, float *d_weight, float *d_bias, const float *d_y, const float *x, const float *weight, int n) { /* Forward stats */ float mean = 0.0f; for (int i = 0; i < n; i++) mean += x[i]; mean /= (float)n; float var = 0.0f; for (int i = 0; i < n; i++) { float a = x[i] - mean; var += a * a; } var /= (float)n; float inv_std = 1.0f / sqrtf(var + EPSILON); /* Accumulate d_gamma, d_beta and helper sums for dx */ float sum_dy_gamma = 0.0f; float sum_dy_gamma_xhat = 0.0f; for (int i = 0; i < n; i++) { float x_hat = (x[i] - mean) * inv_std; float dy = d_y[i]; d_weight[i] += dy * x_hat; d_bias[i] += dy; float dy_gamma = dy * weight[i]; sum_dy_gamma += dy_gamma; sum_dy_gamma_xhat += dy_gamma * x_hat; } /* dx formula */ float inv_n = 1.0f / (float)n; for (int i = 0; i < n; i++) { float x_hat = (x[i] - mean) * inv_std; float dy_gamma = d_y[i] * weight[i]; d_x[i] = inv_n * inv_std * 
((float)n * dy_gamma - sum_dy_gamma - x_hat * sum_dy_gamma_xhat); } } static void layer_norm_backward_seq(Tensor *d_x, Tensor *d_weight, Tensor *d_bias, const Tensor *d_y, const Tensor *x, const Tensor *weight) { int seq_len = x->rows; int n_embd = x->cols; for (int s = 0; s < seq_len; s++) { layer_norm_backward_single( d_x->data + s * n_embd, d_weight->data, d_bias->data, d_y->data + s * n_embd, x->data + s * n_embd, weight->data, n_embd ); } } /* Softmax cross-entropy backward * For softmax with cross-entropy, the gradient is simply: probs - one_hot(target) * This is one of the nice properties of this combination */ static void softmax_cross_entropy_backward(Tensor *d_logits, const Tensor *logits, const int *targets) { int seq_len = logits->rows; int vocab_size = logits->cols; float *probs = (float *)malloc((size_t)vocab_size * sizeof(float)); float eps = 0.1f; // smoothing factor float invV = 1.0f / (float)vocab_size; for (int t = 0; t < seq_len; t++) { softmax_vec(probs, logits->data + t * vocab_size, vocab_size); int target = targets[t]; for (int v = 0; v < vocab_size; v++) { float q = (v == target) ? (1.0f - eps + eps * invV) // mostly target… : (eps * invV); // …but small mass on others d_logits->data[t * vocab_size + v] = (probs[v] - q) / (float)seq_len; } } free(probs); } /* ============================================================ * WKV Backward Pass * ============================================================ * * Forward recurrence: * wkv_t = (num_t + tf * k_t * v_t) / (den_t + tf * k_t + eps) * num_{t+1} = decay_t * num_t + k_t * v_t * den_{t+1} = decay_t * den_t + k_t * * Where: * k_t = exp(k_pre_t) * tf = exp(time_first) * decay_t = sigmoid(decay_base + decay_delta_t) * * We need backward pass through this recurrence. 
*/ static void wkv_backward( Tensor *d_k_exp, Tensor *d_v, Tensor *d_decay, Tensor *d_time_first_exp, Tensor *d_mem_gate_write_proj, Tensor *d_mem_gate_read_proj, const Tensor *d_wkv, const Tensor *k_exp, const Tensor *v, const Tensor *decay, const Tensor *time_first_exp, Tensor *num_states, Tensor *den_states, Tensor *write_gates, Tensor *read_gates, int seq_len, int n_embd, int n_head, int n_slots, const float *alibi_slopes ) { int hdim = n_embd / n_head; tensor_zero(d_k_exp); tensor_zero(d_v); tensor_zero(d_decay); tensor_zero(d_time_first_exp); /* Gradient accumulators for slot states */ /* d_num_next[s * n_embd + i], d_den_next[s * n_embd + i] */ float *d_num_next = (float *)calloc((size_t)(n_slots * n_embd), sizeof(float)); float *d_den_next = (float *)calloc((size_t)(n_slots * n_embd), sizeof(float)); /* Gradient accumulators for gate logits (pre-softmax) */ /* We'll accumulate d_write_gates and d_read_gates, then backprop through softmax outside this function */ float *d_write_g = (float *)calloc((size_t)n_slots, sizeof(float)); float *d_read_g = (float *)calloc((size_t)n_slots, sizeof(float)); float *alibi_decay_arr = (float *)malloc((size_t)n_head * sizeof(float)); for (int h = 0; h < n_head; h++) { alibi_decay_arr[h] = expf(-alibi_slopes[h]); } for (int t = seq_len - 1; t >= 0; t--) { memset(d_write_g, 0, (size_t)n_slots * sizeof(float)); memset(d_read_g, 0, (size_t)n_slots * sizeof(float)); for (int h = 0; h < n_head; h++) { int base_h = h * hdim; float ad = alibi_decay_arr[h]; for (int d_idx = 0; d_idx < hdim; d_idx++) { int i = base_h + d_idx; int idx = t * n_embd + i; float ki = k_exp->data[idx]; float vi = v->data[idx]; float di = decay->data[idx]; float tfi = time_first_exp->data[i]; float kv = ki * vi; float combined = di * ad; /* Recompute read values */ float read_num = 0.0f, read_den = 0.0f; for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; read_num += read_gates[t].data[s] * num_states[t].data[si]; read_den += read_gates[t].data[s] 
* den_states[t].data[si]; } float numerator = read_num + tfi * kv; float denominator = read_den + tfi * ki + EPSILON; float inv_den = 1.0f / denominator; float dw = d_wkv->data[idx]; float d_numerator = dw * inv_den; float d_denominator = -dw * numerator * inv_den * inv_den; /* Gradients for read */ float d_read_num = d_numerator; float d_read_den = d_denominator; for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; d_read_g[s] += d_read_num * num_states[t].data[si]; d_read_g[s] += d_read_den * den_states[t].data[si]; /* Gradient to states from read */ float d_state_num = d_read_num * read_gates[t].data[s]; float d_state_den = d_read_den * read_gates[t].data[s]; /* Add gradient from future via state update */ d_state_num += d_num_next[si] * combined; d_state_den += d_den_next[si] * combined; /* Gradient for decay from state update */ d_decay->data[idx] += (d_num_next[si] * num_states[t].data[si] + d_den_next[si] * den_states[t].data[si]) * ad; /* Gradient for write gate */ d_write_g[s] += d_num_next[si] * kv + d_den_next[si] * ki; /* Gradient for k, v from write */ float wg = write_gates[t].data[s]; d_k_exp->data[idx] += d_num_next[si] * wg * vi + d_den_next[si] * wg; d_v->data[idx] += d_num_next[si] * wg * ki; /* Propagate to previous timestep */ d_num_next[si] = d_state_num; d_den_next[si] = d_state_den; } /* Gradients from wkv output for tf, k, v */ d_time_first_exp->data[i] += d_numerator * kv + d_denominator * ki; d_k_exp->data[idx] += d_numerator * tfi * vi + d_denominator * tfi; d_v->data[idx] += d_numerator * tfi * ki; } } /* Backprop through softmax for write/read gates at timestep t */ /* d_logits = softmax_backward(d_gates, gates) */ /* For softmax: d_logit_i = sum_j (d_gate_j * gate_j * (delta_ij - gate_i)) */ /* We need to accumulate into d_mem_gate_write_proj / d_mem_gate_read_proj */ { float *wg = write_gates[t].data; float *rg = read_gates[t].data; float d_wl[n_slots], d_rl[n_slots]; /* VLA ok, n_slots is small */ /* Softmax backward for 
write gates */ float dot_w = 0.0f; for (int s = 0; s < n_slots; s++) dot_w += d_write_g[s] * wg[s]; for (int s = 0; s < n_slots; s++) { d_wl[s] = wg[s] * (d_write_g[s] - dot_w); } /* Softmax backward for read gates */ float dot_r = 0.0f; for (int s = 0; s < n_slots; s++) dot_r += d_read_g[s] * rg[s]; for (int s = 0; s < n_slots; s++) { d_rl[s] = rg[s] * (d_read_g[s] - dot_r); } /* These are gradients w.r.t. the gate logits = x_ln1 @ mem_gate_{write,read} * We accumulate d_x_ln1 contribution and d_W contribution outside */ /* Store in the gradient tensors for later matmul backward */ for (int s = 0; s < n_slots; s++) { d_mem_gate_write_proj->data[t * n_slots + s] = d_wl[s]; d_mem_gate_read_proj->data[t * n_slots + s] = d_rl[s]; } } } /* Clamp */ for (int i = 0; i < d_k_exp->size; i++) { d_k_exp->data[i] = clamp_f(d_k_exp->data[i], -GRAD_CLIP, GRAD_CLIP); d_v->data[i] = clamp_f(d_v->data[i], -GRAD_CLIP, GRAD_CLIP); } for (int i = 0; i < d_decay->size; i++) { d_decay->data[i] = clamp_f(d_decay->data[i], -GRAD_CLIP, GRAD_CLIP); } for (int i = 0; i < n_embd; i++) { d_time_first_exp->data[i] = clamp_f(d_time_first_exp->data[i], -GRAD_CLIP, GRAD_CLIP); } free(d_num_next); free(d_den_next); free(d_write_g); free(d_read_g); free(alibi_decay_arr); }

/* ============================================================
 * Multi-Scale Token Shift Backward
 * ============================================================ */

/*
 * Backward pass of the multi-scale token shift.
 *
 * Forward model, per step t and channel i (history before t = 0 reads as 0):
 *   out[t,i] = (s1*x[t-1,i] + s2*x[t-2,i] + s4*x[t-4,i]) / S
 *   S        = s1 + s2 + s4 + EPSILON,   sK = sigmoid(time_shift_wK[i])
 *
 * Zeroes all four output tensors, then fills:
 *   d_x        - dL/dx for every row < seq_len
 *   d_shift_wK - dL/d(time_shift_wK) before the sigmoid, summed over t
 */
static void multi_scale_shift_backward(
    Tensor *d_x,                /* out: gradient w.r.t. input x */
    Tensor *d_shift_w1,         /* out: gradient w.r.t. time_shift_w1 (before sigmoid) */
    Tensor *d_shift_w2,         /* out: gradient w.r.t. time_shift_w2 (before sigmoid) */
    Tensor *d_shift_w4,         /* out: gradient w.r.t. time_shift_w4 (before sigmoid) */
    const Tensor *d_out,        /* incoming gradient w.r.t. the shifted output */
    const Tensor *x,            /* original (unshifted) input */
    const Tensor *shift_w1_sig, /* cached sigmoid(time_shift_w1) */
    const Tensor *shift_w2_sig, /* cached sigmoid(time_shift_w2) */
    const Tensor *shift_w4_sig, /* cached sigmoid(time_shift_w4) */
    const Tensor *shift_w_sum,  /* cached s1 + s2 + s4 + EPSILON */
    int seq_len,
    int n_embd
)
{
    tensor_zero(d_x);
    tensor_zero(d_shift_w1);
    tensor_zero(d_shift_w2);
    tensor_zero(d_shift_w4);

    /* Channel-major sweep: the mixing weights depend only on the channel, so
     * they are read once per channel. For any single output element the
     * contributions still arrive in increasing t order. */
    for (int ch = 0; ch < n_embd; ch++) {
        const float s1 = shift_w1_sig->data[ch];
        const float s2 = shift_w2_sig->data[ch];
        const float s4 = shift_w4_sig->data[ch];
        const float inv_sum = 1.0f / shift_w_sum->data[ch];

        for (int step = 0; step < seq_len; step++) {
            const float g = d_out->data[step * n_embd + ch];
            const float h1 = (step >= 1) ? x->data[(step - 1) * n_embd + ch] : 0.0f;
            const float h2 = (step >= 2) ? x->data[(step - 2) * n_embd + ch] : 0.0f;
            const float h4 = (step >= 4) ? x->data[(step - 4) * n_embd + ch] : 0.0f;

            /* Inputs: d out / d x[t-k] = s_k / S */
            if (step >= 1) {
                d_x->data[(step - 1) * n_embd + ch] += g * s1 * inv_sum;
            }
            if (step >= 2) {
                d_x->data[(step - 2) * n_embd + ch] += g * s2 * inv_sum;
            }
            if (step >= 4) {
                d_x->data[(step - 4) * n_embd + ch] += g * s4 * inv_sum;
            }

            /* Weights: d out / d s_k = x_k / S - num / S^2 (quotient rule) */
            const float num = s1 * h1 + s2 * h2 + s4 * h4;
            const float num_term = num * inv_sum * inv_sum;
            const float raw1 = g * (h1 * inv_sum - num_term);
            const float raw2 = g * (h2 * inv_sum - num_term);
            const float raw4 = g * (h4 * inv_sum - num_term);

            /* Chain through the sigmoid: ds/dz = s * (1 - s) */
            d_shift_w1->data[ch] += raw1 * s1 * (1.0f - s1);
            d_shift_w2->data[ch] += raw2 * s2 * (1.0f - s2);
            d_shift_w4->data[ch] += raw4 * s4 * (1.0f - s4);
        }
    }
}

/* ============================================================
 * Token Mixing Backward
 * ============================================================ */

/*
 * Backward pass of the receptance/key/value token mixing.
 *
 * Forward model, for each branch b in {r, k, v}:
 *   x_b[t,i] = x_ln1[t,i] * m_b[i] + x_shifted[t,i] * (1 - m_b[i]),
 *   m_b = sigmoid(time_mix_b)
 *
 * Zeroes every output, then fills d_x_ln1 and d_x_shifted (seq_len rows) and
 * the pre-sigmoid time_mix gradients (summed over t).
 */
static void token_mixing_backward(
    Tensor *d_x_ln1,         /* out: gradient w.r.t. layer norm output */
    Tensor *d_x_shifted,     /* out: gradient w.r.t. shifted input */
    Tensor *d_mix_r,         /* out: gradient w.r.t. time_mix_r (before sigmoid) */
    Tensor *d_mix_k,         /* out: gradient w.r.t. time_mix_k (before sigmoid) */
    Tensor *d_mix_v,         /* out: gradient w.r.t. time_mix_v (before sigmoid) */
    const Tensor *d_xr,      /* incoming gradient for xr */
    const Tensor *d_xk,      /* incoming gradient for xk */
    const Tensor *d_xv,      /* incoming gradient for xv */
    const Tensor *x_ln1,     /* cached layer norm output */
    const Tensor *x_shifted, /* cached shifted input */
    const Tensor *mix_r_sig, /* cached sigmoid(time_mix_r) */
    const Tensor *mix_k_sig, /* cached sigmoid(time_mix_k) */
    const Tensor *mix_v_sig, /* cached sigmoid(time_mix_v) */
    int seq_len,
    int n_embd
)
{
    /* Branch tables in r, k, v order. Each element of d_x_ln1 / d_x_shifted
     * therefore accumulates its three contributions as r, then k, then v. */
    const Tensor *const branch_grad[3] = { d_xr, d_xk, d_xv };
    const Tensor *const branch_mix[3] = { mix_r_sig, mix_k_sig, mix_v_sig };
    Tensor *const branch_logit_grad[3] = { d_mix_r, d_mix_k, d_mix_v };

    tensor_zero(d_x_ln1);
    tensor_zero(d_x_shifted);
    tensor_zero(d_mix_r);
    tensor_zero(d_mix_k);
    tensor_zero(d_mix_v);

    for (int step = 0; step < seq_len; step++) {
        for (int ch = 0; ch < n_embd; ch++) {
            const int pos = step * n_embd + ch;
            const float cur = x_ln1->data[pos];
            const float prev = x_shifted->data[pos];

            for (int b = 0; b < 3; b++) {
                const float m = branch_mix[b]->data[ch];
                const float g = branch_grad[b]->data[pos];

                d_x_ln1->data[pos] += g * m;
                d_x_shifted->data[pos] += g * (1.0f - m);
                /* d/dm of the lerp is (cur - prev); then through the sigmoid. */
                branch_logit_grad[b]->data[ch] += g * (cur - prev) * m * (1.0f - m);
            }
        }
    }
}

/* ============================================================
 * Channel Mixing Backward
 * ============================================================ */

/*
 * Backward pass of the channel-mix token shift:
 *   xm[t,i] = x_ln2[t,i] * m[i] + x_ln2[t-1,i] * (1 - m[i]),  m = sigmoid(channel_mix)
 * with x_ln2[-1,*] treated as 0. Zeroes, then fills, d_x_ln2 and d_channel_mix.
 */
static void channel_mix_shift_backward(
    Tensor *d_x_ln2,       /* Gradient w.r.t. layer norm output */
    Tensor *d_channel_mix, /* Gradient w.r.t.
channel_mix parameter */ const Tensor *d_xm, /* Incoming gradient */ const Tensor *x_ln2, /* Cached layer norm output */ const Tensor *cm_mix_sig, /* Cached sigmoid(channel_mix) */ int seq_len, int n_embd ) { tensor_zero(d_x_ln2); tensor_zero(d_channel_mix); for (int t = 0; t < seq_len; t++) { for (int i = 0; i < n_embd; i++) { int idx = t * n_embd + i; float mix = cm_mix_sig->data[i]; float x_curr = x_ln2->data[idx]; float x_prev = (t > 0) ? x_ln2->data[(t-1) * n_embd + i] : 0.0f; float d_xm_ti = d_xm->data[idx]; /* xm = x_curr * mix + x_prev * (1 - mix) */ d_x_ln2->data[idx] += d_xm_ti * mix; if (t > 0) { d_x_ln2->data[(t-1) * n_embd + i] += d_xm_ti * (1.0f - mix); } float d_mix = d_xm_ti * (x_curr - x_prev); d_channel_mix->data[i] += d_mix * mix * (1.0f - mix); } } } /* ============================================================ * Forward Pass with Caching * ============================================================ */ static float forward_with_cache(ForwardCache *cache, const int *tokens, int seq_len, const ModelParams *mp, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int n_head = cfg->n_head; //int hdim = n_embd / n_head; /* Embedding lookup */ for (int t = 0; t < seq_len; t++) { int tok = tokens[t]; memcpy(cache->emb_out.data + t * n_embd, mp->emb.data + tok * n_embd, (size_t)n_embd * sizeof(float)); } /* Initial layer norm */ layer_norm_seq(&cache->x_ln0, &cache->emb_out, &mp->ln0_weight, &mp->ln0_bias); /* Copy x_ln0 to first layer input - we'll use x_final as working space */ tensor_copy(&cache->x_final, &cache->x_ln0); /* Process layers */ for (int layer_idx = 0; layer_idx < mp->n_layers; layer_idx++) { const LayerParams *lp = &mp->layers[layer_idx]; LayerCache *lc = &cache->layers[layer_idx]; /* ============ TimeMix Forward ============ */ tensor_copy(&lc->x_in, &cache->x_final); /* Layer norm 1 */ layer_norm_seq(&lc->x_ln1, &cache->x_final, &lp->ln1_weight, &lp->ln1_bias); /* Multi-scale shift - compute sigmoid weights */ 
sigmoid_vec(lc->shift_w1_sig.data, lp->time_shift_w1.data, n_embd); sigmoid_vec(lc->shift_w2_sig.data, lp->time_shift_w2.data, n_embd); sigmoid_vec(lc->shift_w4_sig.data, lp->time_shift_w4.data, n_embd); for (int i = 0; i < n_embd; i++) { lc->shift_w_sum.data[i] = lc->shift_w1_sig.data[i] + lc->shift_w2_sig.data[i] + lc->shift_w4_sig.data[i] + EPSILON; } /* Apply multi-scale shift */ for (int t = 0; t < seq_len; t++) { for (int i = 0; i < n_embd; i++) { float w1 = lc->shift_w1_sig.data[i] / lc->shift_w_sum.data[i]; float w2 = lc->shift_w2_sig.data[i] / lc->shift_w_sum.data[i]; float w4 = lc->shift_w4_sig.data[i] / lc->shift_w_sum.data[i]; float x1 = (t >= 1) ? lc->x_ln1.data[(t-1) * n_embd + i] : 0.0f; float x2 = (t >= 2) ? lc->x_ln1.data[(t-2) * n_embd + i] : 0.0f; float x4 = (t >= 4) ? lc->x_ln1.data[(t-4) * n_embd + i] : 0.0f; lc->x_shifted.data[t * n_embd + i] = w1 * x1 + w2 * x2 + w4 * x4; } } /* Token mixing weights */ sigmoid_vec(lc->mix_r_sig.data, lp->time_mix_r.data, n_embd); sigmoid_vec(lc->mix_k_sig.data, lp->time_mix_k.data, n_embd); sigmoid_vec(lc->mix_v_sig.data, lp->time_mix_v.data, n_embd); for (int t = 0; t < seq_len; t++) { for (int i = 0; i < n_embd; i++) { int idx = t * n_embd + i; float mr = lc->mix_r_sig.data[i]; float mk = lc->mix_k_sig.data[i]; float mv = lc->mix_v_sig.data[i]; lc->xr.data[idx] = lc->x_ln1.data[idx] * mr + lc->x_shifted.data[idx] * (1.0f - mr); lc->xk.data[idx] = lc->x_ln1.data[idx] * mk + lc->x_shifted.data[idx] * (1.0f - mk); lc->xv.data[idx] = lc->x_ln1.data[idx] * mv + lc->x_shifted.data[idx] * (1.0f - mv); } } /* Projections */ matmul(&lc->r_pre, &lc->xr, &lp->Wr); matmul(&lc->k_pre, &lc->xk, &lp->Wk); matmul(&lc->v, &lc->xv, &lp->Wv); /* Data-dependent decay */ matmul(&lc->decay_tmp, &lc->x_ln1, &lp->decay_lora_a); matmul(&lc->decay_delta, &lc->decay_tmp, &lp->decay_lora_b); for (int t = 0; t < seq_len; t++) { for (int i = 0; i < n_embd; i++) { int idx = t * n_embd + i; lc->decay_pre.data[idx] = lp->decay_base.data[i] 
+ lc->decay_delta.data[idx]; lc->decay.data[idx] = sigmoid_f(lc->decay_pre.data[idx]); } } /* Receptance (sigmoid) and k (exp) */ for (int i = 0; i < lc->r_pre.size; i++) { lc->r.data[i] = sigmoid_f(lc->r_pre.data[i]); } for (int i = 0; i < n_embd; i++) { lc->time_first_exp.data[i] = expf(clamp_f(lp->time_first.data[i], -10.0f, 10.0f)); } for (int i = 0; i < lc->k_pre.size; i++) { lc->k_exp.data[i] = expf(clamp_f(lc->k_pre.data[i], -10.0f, 10.0f)); } /* WKV sequential scan with state caching and ALiBi */ /* WKV scan with multi-slot memory and ALiBi */ { int n_slots = cfg->n_mem_slots; int hdim_val = n_embd / n_head; /* Zero initial states */ tensor_zero(&lc->num_states[0]); tensor_zero(&lc->den_states[0]); float *wl = (float *)malloc((size_t)n_slots * sizeof(float)); float *rl = (float *)malloc((size_t)n_slots * sizeof(float)); for (int t = 0; t < seq_len; t++) { /* Compute gates for this timestep */ matvec(wl, lc->x_ln1.data + t * n_embd, &lp->mem_gate_write); softmax_vec(lc->write_gates[t].data, wl, n_slots); matvec(rl, lc->x_ln1.data + t * n_embd, &lp->mem_gate_read); softmax_vec(lc->read_gates[t].data, rl, n_slots); for (int h = 0; h < n_head; h++) { int base = h * hdim_val; float ad = expf(-lp->alibi_slopes.data[h]); for (int d = 0; d < hdim_val; d++) { int i = base + d; int idx = t * n_embd + i; float ki = lc->k_exp.data[idx]; float vi = lc->v.data[idx]; float di = lc->decay.data[idx]; float tfi = lc->time_first_exp.data[i]; float kv = ki * vi; float combined = di * ad; /* Read from slots */ float read_num = 0.0f, read_den = 0.0f; for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; read_num += lc->read_gates[t].data[s] * lc->num_states[t].data[si]; read_den += lc->read_gates[t].data[s] * lc->den_states[t].data[si]; } lc->wkv.data[idx] = (read_num + tfi * kv) / (read_den + tfi * ki + EPSILON); /* Write to slots */ for (int s = 0; s < n_slots; s++) { int si = s * n_embd + i; float wg = lc->write_gates[t].data[s]; lc->num_states[t+1].data[si] = combined 
* lc->num_states[t].data[si] + wg * kv; lc->den_states[t+1].data[si] = combined * lc->den_states[t].data[si] + wg * ki; } } } } free(wl); free(rl); } /* Apply receptance and output projection */ for (int i = 0; i < lc->wkv.size; i++) { lc->wkv_r.data[i] = lc->wkv.data[i] * lc->r.data[i]; } matmul(&lc->tm_out, &lc->wkv_r, &lp->Wo); /* Residual connection */ for (int i = 0; i < cache->x_final.size; i++) { lc->x_after_tm.data[i] = cache->x_final.data[i] + lc->tm_out.data[i]; } /* ============ ChannelMix Forward ============ */ /* Layer norm 2 */ layer_norm_seq(&lc->x_ln2, &lc->x_after_tm, &lp->ln2_weight, &lp->ln2_bias); /* Channel mixing */ sigmoid_vec(lc->cm_mix_sig.data, lp->channel_mix.data, n_embd); for (int t = 0; t < seq_len; t++) { for (int i = 0; i < n_embd; i++) { int idx = t * n_embd + i; float mix = lc->cm_mix_sig.data[i]; float x_curr = lc->x_ln2.data[idx]; float x_prev = (t > 0) ? lc->x_ln2.data[(t-1) * n_embd + i] : 0.0f; lc->xm.data[idx] = x_curr * mix + x_prev * (1.0f - mix); } } /* SwiGLU FFN */ matmul(&lc->gate_pre, &lc->xm, &lp->ffn_gate); matmul(&lc->up_val, &lc->xm, &lp->ffn_up); for (int i = 0; i < lc->gate_pre.size; i++) { lc->gate_silu.data[i] = silu_f(lc->gate_pre.data[i]); lc->hidden.data[i] = lc->gate_silu.data[i] * lc->up_val.data[i]; } matmul(&lc->cm_out, &lc->hidden, &lp->ffn_down); /* Residual connection - update x_final for next layer */ for (int i = 0; i < cache->x_final.size; i++) { cache->x_final.data[i] = lc->x_after_tm.data[i] + lc->cm_out.data[i]; } } /* Output layer norm and head */ layer_norm_seq(&cache->x_ln_out, &cache->x_final, &mp->ln_out_weight, &mp->ln_out_bias); matmul(&cache->logits, &cache->x_ln_out, &mp->head); /* Compute loss (return for monitoring) */ float loss = cross_entropy_loss(&cache->logits, tokens + 1, seq_len - 1); return loss; } /* ============================================================ * Backward Pass * ============================================================ */ static void 
backward_pass(ModelGrads *grads, const ForwardCache *cache, const int *tokens, int seq_len, const ModelParams *mp, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int vocab_size = cfg->vocab_size; int ffn_h = ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; int target_len = seq_len - 1; /* We predict tokens[1:] from tokens[:-1] */ /* Allocate working gradient tensors */ Tensor d_logits = tensor_alloc(target_len, vocab_size); Tensor d_x_ln_out = tensor_alloc(target_len, n_embd); Tensor d_x_final = tensor_alloc(target_len, n_embd); Tensor d_x_tmp = tensor_alloc(target_len, n_embd); /* Softmax cross-entropy backward */ /* Create view of logits for target_len positions */ Tensor logits_view = { .data = cache->logits.data, .rows = target_len, .cols = vocab_size, .size = target_len * vocab_size }; softmax_cross_entropy_backward(&d_logits, &logits_view, tokens + 1); /* Head backward: logits = x_ln_out @ head */ matmul_backward_x(&d_x_ln_out, &d_logits, &mp->head); /* Create view of x_ln_out for target_len */ Tensor x_ln_out_view = { .data = cache->x_ln_out.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; matmul_backward_w(&grads->head, &d_logits, &x_ln_out_view); /* Output layer norm backward */ Tensor x_final_view = { .data = cache->x_final.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; layer_norm_backward_seq(&d_x_final, &grads->ln_out_weight, &grads->ln_out_bias, &d_x_ln_out, &x_final_view, &mp->ln_out_weight); /* Backward through layers (in reverse order) */ for (int layer_idx = mp->n_layers - 1; layer_idx >= 0; layer_idx--) { const LayerParams *lp = &mp->layers[layer_idx]; const LayerCache *lc = &cache->layers[layer_idx]; LayerGrads *lg = &grads->layers[layer_idx]; /* ============ ChannelMix Backward ============ */ /* d_x_final comes from residual: x_final = x_after_tm + cm_out */ /* So d_x_after_tm and d_cm_out both get d_x_final */ Tensor d_cm_out = tensor_alloc(target_len, n_embd); Tensor d_hidden = 
tensor_alloc(target_len, ffn_h); Tensor d_gate_silu = tensor_alloc(target_len, ffn_h); Tensor d_up_val = tensor_alloc(target_len, ffn_h); Tensor d_gate_pre = tensor_alloc(target_len, ffn_h); Tensor d_xm = tensor_alloc(target_len, n_embd); Tensor d_x_ln2 = tensor_alloc(target_len, n_embd); Tensor d_x_after_tm = tensor_alloc(target_len, n_embd); /* d_cm_out = d_x_final */ tensor_copy(&d_cm_out, &d_x_final); /* cm_out = hidden @ ffn_down backward */ Tensor hidden_view = { .data = lc->hidden.data, .rows = target_len, .cols = ffn_h, .size = target_len * ffn_h }; matmul_backward_x(&d_hidden, &d_cm_out, &lp->ffn_down); matmul_backward_w(&lg->ffn_down, &d_cm_out, &hidden_view); /* hidden = gate_silu * up_val */ for (int i = 0; i < target_len * ffn_h; i++) { d_gate_silu.data[i] = d_hidden.data[i] * lc->up_val.data[i]; d_up_val.data[i] = d_hidden.data[i] * lc->gate_silu.data[i]; } /* gate_silu = silu(gate_pre) */ silu_backward(d_gate_pre.data, d_gate_silu.data, lc->gate_pre.data, target_len * ffn_h); /* gate_pre = xm @ ffn_gate, up_val = xm @ ffn_up */ Tensor xm_view = { .data = lc->xm.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; Tensor d_xm_gate = tensor_alloc(target_len, n_embd); Tensor d_xm_up = tensor_alloc(target_len, n_embd); matmul_backward_x(&d_xm_gate, &d_gate_pre, &lp->ffn_gate); matmul_backward_w(&lg->ffn_gate, &d_gate_pre, &xm_view); matmul_backward_x(&d_xm_up, &d_up_val, &lp->ffn_up); matmul_backward_w(&lg->ffn_up, &d_up_val, &xm_view); /* Combine gradients for xm */ for (int i = 0; i < target_len * n_embd; i++) { d_xm.data[i] = d_xm_gate.data[i] + d_xm_up.data[i]; } tensor_free(&d_xm_gate); tensor_free(&d_xm_up); /* Channel mixing shift backward */ Tensor x_ln2_view = { .data = lc->x_ln2.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; channel_mix_shift_backward(&d_x_ln2, &lg->channel_mix, &d_xm, &x_ln2_view, &lc->cm_mix_sig,target_len, n_embd); /* Layer norm 2 backward */ Tensor x_after_tm_view = { .data = 
lc->x_after_tm.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; layer_norm_backward_seq(&d_x_after_tm, &lg->ln2_weight, &lg->ln2_bias, &d_x_ln2, &x_after_tm_view, &lp->ln2_weight); /* Add residual gradient */ for (int i = 0; i < target_len * n_embd; i++) { d_x_after_tm.data[i] += d_x_final.data[i]; } tensor_free(&d_cm_out); tensor_free(&d_hidden); tensor_free(&d_gate_silu); tensor_free(&d_up_val); tensor_free(&d_gate_pre); tensor_free(&d_xm); tensor_free(&d_x_ln2); /* ============ TimeMix Backward ============ */ Tensor d_tm_out = tensor_alloc(target_len, n_embd); Tensor d_wkv_r = tensor_alloc(target_len, n_embd); Tensor d_wkv = tensor_alloc(target_len, n_embd); Tensor d_r = tensor_alloc(target_len, n_embd); Tensor d_k_exp = tensor_alloc(target_len, n_embd); Tensor d_v = tensor_alloc(target_len, n_embd); Tensor d_decay = tensor_alloc(target_len, n_embd); Tensor d_time_first_exp = tensor_alloc_1d(n_embd); Tensor d_decay_pre = tensor_alloc(target_len, n_embd); Tensor d_decay_delta = tensor_alloc(target_len, n_embd); Tensor d_decay_tmp = tensor_alloc(target_len, lora_rank); Tensor d_x_ln1_decay = tensor_alloc(target_len, n_embd); Tensor d_r_pre = tensor_alloc(target_len, n_embd); Tensor d_k_pre = tensor_alloc(target_len, n_embd); Tensor d_xr = tensor_alloc(target_len, n_embd); Tensor d_xk = tensor_alloc(target_len, n_embd); Tensor d_xv = tensor_alloc(target_len, n_embd); Tensor d_x_ln1 = tensor_alloc(target_len, n_embd); Tensor d_x_shifted = tensor_alloc(target_len, n_embd); Tensor d_x_layer_in = tensor_alloc(target_len, n_embd); /* d_tm_out comes from residual: x_after_tm = x + tm_out */ tensor_copy(&d_tm_out, &d_x_after_tm); /* tm_out = wkv_r @ Wo backward */ Tensor wkv_r_view = { .data = lc->wkv_r.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; matmul_backward_x(&d_wkv_r, &d_tm_out, &lp->Wo); matmul_backward_w(&lg->Wo, &d_tm_out, &wkv_r_view); /* wkv_r = wkv * r */ for (int i = 0; i < target_len * n_embd; i++) { 
d_wkv.data[i] = d_wkv_r.data[i] * lc->r.data[i]; d_r.data[i] = d_wkv_r.data[i] * lc->wkv.data[i]; } /* decay_tmp = x_ln1 @ decay_lora_a backward */ Tensor x_ln1_view = { .data = lc->x_ln1.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; /* WKV backward - updated call with multi-head ALiBi */ /* Memory gate backward */ Tensor d_gate_write_logits = tensor_alloc(target_len, cfg->n_mem_slots); Tensor d_gate_read_logits = tensor_alloc(target_len, cfg->n_mem_slots); /* wkv_backward now fills these */ wkv_backward(&d_k_exp, &d_v, &d_decay, &d_time_first_exp, &d_gate_write_logits, &d_gate_read_logits, &d_wkv, &lc->k_exp, &lc->v, &lc->decay, &lc->time_first_exp, lc->num_states, lc->den_states, lc->write_gates, lc->read_gates, target_len, n_embd, cfg->n_head, cfg->n_mem_slots, lp->alibi_slopes.data); /* gate_write_logits = x_ln1 @ mem_gate_write backward */ Tensor d_x_ln1_wgate = tensor_alloc(target_len, n_embd); matmul_backward_x(&d_x_ln1_wgate, &d_gate_write_logits, &lp->mem_gate_write); matmul_backward_w(&lg->mem_gate_write, &d_gate_write_logits, &x_ln1_view); Tensor d_x_ln1_rgate = tensor_alloc(target_len, n_embd); matmul_backward_x(&d_x_ln1_rgate, &d_gate_read_logits, &lp->mem_gate_read); matmul_backward_w(&lg->mem_gate_read, &d_gate_read_logits, &x_ln1_view); /* Add to d_x_ln1 later when combining all x_ln1 gradients */ /* After token_mixing_backward and before multi_scale_shift_backward: */ for (int i = 0; i < target_len * n_embd; i++) { d_x_ln1.data[i] += d_x_ln1_decay.data[i] + d_x_ln1_wgate.data[i] + d_x_ln1_rgate.data[i]; } tensor_free(&d_gate_write_logits); tensor_free(&d_gate_read_logits); tensor_free(&d_x_ln1_wgate); tensor_free(&d_x_ln1_rgate); /* time_first backward: time_first_exp = exp(time_first) */ for (int i = 0; i < n_embd; i++) { float x = lp->time_first.data[i]; if (x >= -10.0f && x <= 10.0f) { lg->time_first.data[i] += d_time_first_exp.data[i] * lc->time_first_exp.data[i]; } } /* decay backward: decay = sigmoid(decay_pre) */ 
for (int i = 0; i < target_len * n_embd; i++) { float y = lc->decay.data[i]; d_decay_pre.data[i] = d_decay.data[i] * y * (1.0f - y); } /* decay_pre = decay_base + decay_delta */ for (int t = 0; t < target_len; t++) { for (int i = 0; i < n_embd; i++) { int idx = t * n_embd + i; lg->decay_base.data[i] += d_decay_pre.data[idx]; d_decay_delta.data[idx] = d_decay_pre.data[idx]; } } /* decay_delta = decay_tmp @ decay_lora_b backward */ Tensor decay_tmp_view = { .data = lc->decay_tmp.data, .rows = target_len, .cols = lora_rank, .size = target_len * lora_rank }; matmul_backward_x(&d_decay_tmp, &d_decay_delta, &lp->decay_lora_b); matmul_backward_w(&lg->decay_lora_b, &d_decay_delta, &decay_tmp_view); matmul_backward_x(&d_x_ln1_decay, &d_decay_tmp, &lp->decay_lora_a); matmul_backward_w(&lg->decay_lora_a, &d_decay_tmp, &x_ln1_view); /* r backward: r = sigmoid(r_pre) */ sigmoid_backward(d_r_pre.data, d_r.data, lc->r.data, target_len * n_embd); /* k_exp backward: k_exp = exp(k_pre) (clamped) */ exp_backward_clamped(d_k_pre.data, d_k_exp.data, lc->k_pre.data, target_len * n_embd); /* r_pre = xr @ Wr, k_pre = xk @ Wk, v = xv @ Wv backward */ Tensor xr_view = { .data = lc->xr.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; Tensor xk_view = { .data = lc->xk.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; Tensor xv_view = { .data = lc->xv.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; matmul_backward_x(&d_xr, &d_r_pre, &lp->Wr); matmul_backward_w(&lg->Wr, &d_r_pre, &xr_view); matmul_backward_x(&d_xk, &d_k_pre, &lp->Wk); matmul_backward_w(&lg->Wk, &d_k_pre, &xk_view); matmul_backward_x(&d_xv, &d_v, &lp->Wv); matmul_backward_w(&lg->Wv, &d_v, &xv_view); /* Token mixing backward */ token_mixing_backward(&d_x_ln1, &d_x_shifted, &lg->time_mix_r, &lg->time_mix_k, &lg->time_mix_v, &d_xr, &d_xk, &d_xv, &lc->x_ln1, &lc->x_shifted, &lc->mix_r_sig, &lc->mix_k_sig, &lc->mix_v_sig, target_len, n_embd); /* Add gradient from 
decay LoRA path */ for (int i = 0; i < target_len * n_embd; i++) { d_x_ln1.data[i] += d_x_ln1_decay.data[i]; } /* Multi-scale shift backward */ Tensor d_x_ln1_shift = tensor_alloc(target_len, n_embd); multi_scale_shift_backward(&d_x_ln1_shift, &lg->time_shift_w1, &lg->time_shift_w2, &lg->time_shift_w4, &d_x_shifted, &lc->x_ln1, &lc->shift_w1_sig, &lc->shift_w2_sig, &lc->shift_w4_sig, &lc->shift_w_sum, target_len, n_embd); /* Add shift gradients to x_ln1 */ for (int i = 0; i < target_len * n_embd; i++) { d_x_ln1.data[i] += d_x_ln1_shift.data[i]; } tensor_free(&d_x_ln1_shift); Tensor layer_input; if (layer_idx == 0) { layer_input.data = cache->x_ln0.data; layer_input.rows = target_len; layer_input.cols = n_embd; layer_input.size = target_len * n_embd; } else { layer_input = tensor_alloc(target_len, n_embd); for (int i = 0; i < target_len * n_embd; i++) { layer_input.data[i] = lc->x_after_tm.data[i] - lc->tm_out.data[i]; } } /* Layer norm 1 backward - use cached layer input */ Tensor layer_input_view = { .data = lc->x_in.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; layer_norm_backward_seq(&d_x_layer_in, &lg->ln1_weight, &lg->ln1_bias, &d_x_ln1, &layer_input_view, &lp->ln1_weight); if (layer_idx > 0) { tensor_free(&layer_input); } /* Add residual gradient from TM block */ for (int i = 0; i < target_len * n_embd; i++) { d_x_layer_in.data[i] += d_x_after_tm.data[i]; } /* This gradient flows to the previous layer or initial embedding */ tensor_copy(&d_x_final, &d_x_layer_in); /* Free layer-specific gradients */ tensor_free(&d_tm_out); tensor_free(&d_wkv_r); tensor_free(&d_wkv); tensor_free(&d_r); tensor_free(&d_k_exp); tensor_free(&d_v); tensor_free(&d_decay); tensor_free(&d_time_first_exp); tensor_free(&d_decay_pre); tensor_free(&d_decay_delta); tensor_free(&d_decay_tmp); tensor_free(&d_x_ln1_decay); tensor_free(&d_r_pre); tensor_free(&d_k_pre); tensor_free(&d_xr); tensor_free(&d_xk); tensor_free(&d_xv); tensor_free(&d_x_ln1); 
tensor_free(&d_x_shifted); tensor_free(&d_x_layer_in); tensor_free(&d_x_after_tm); } /* Initial layer norm backward */ Tensor d_emb_out = tensor_alloc(target_len, n_embd); Tensor emb_out_view = { .data = cache->emb_out.data, .rows = target_len, .cols = n_embd, .size = target_len * n_embd }; layer_norm_backward_seq(&d_emb_out, &grads->ln0_weight, &grads->ln0_bias, &d_x_final, &emb_out_view, &mp->ln0_weight); /* Embedding backward */ for (int t = 0; t < target_len; t++) { int tok = tokens[t]; for (int i = 0; i < n_embd; i++) { grads->emb.data[tok * n_embd + i] += d_emb_out.data[t * n_embd + i]; } } tensor_free(&d_logits); tensor_free(&d_x_ln_out); tensor_free(&d_x_final); tensor_free(&d_x_tmp); tensor_free(&d_emb_out); } /* ============================================================ * Adam Optimizer * ============================================================ */ typedef struct { float beta1; float beta2; float epsilon; float weight_decay; int t; /* timestep */ } AdamConfig; typedef struct { Tensor m; /* First moment */ Tensor v; /* Second moment */ } AdamState; typedef struct { AdamState emb; AdamState ln0_weight, ln0_bias; struct { AdamState ln1_weight, ln1_bias; AdamState ln2_weight, ln2_bias; AdamState time_shift_w1, time_shift_w2, time_shift_w4; AdamState time_mix_r, time_mix_k, time_mix_v; AdamState decay_lora_a, decay_lora_b; AdamState decay_base; AdamState time_first; AdamState Wr, Wk, Wv, Wo; AdamState channel_mix; AdamState ffn_gate, ffn_up, ffn_down; AdamState mem_gate_write; AdamState mem_gate_read; } *layers; AdamState ln_out_weight, ln_out_bias; AdamState head; int n_layers; } AdamStates; static void init_adam_state(AdamState *as, int rows, int cols) { as->m = tensor_alloc(rows, cols); as->v = tensor_alloc(rows, cols); } static void free_adam_state(AdamState *as) { tensor_free(&as->m); tensor_free(&as->v); } static void init_adam_states(AdamStates *as, const lrnnConfig *cfg) { int n_embd = cfg->n_embd; int vocab_size = cfg->vocab_size; int ffn_h = 
ffn_hidden(cfg); int lora_rank = cfg->decay_lora_rank; as->n_layers = cfg->n_layer; init_adam_state(&as->emb, vocab_size, n_embd); init_adam_state(&as->ln0_weight, n_embd, 1); init_adam_state(&as->ln0_bias, n_embd, 1); as->layers = calloc((size_t)cfg->n_layer, sizeof(*as->layers)); for (int i = 0; i < cfg->n_layer; i++) { init_adam_state(&as->layers[i].ln1_weight, n_embd, 1); init_adam_state(&as->layers[i].ln1_bias, n_embd, 1); init_adam_state(&as->layers[i].ln2_weight, n_embd, 1); init_adam_state(&as->layers[i].ln2_bias, n_embd, 1); init_adam_state(&as->layers[i].time_shift_w1, n_embd, 1); init_adam_state(&as->layers[i].time_shift_w2, n_embd, 1); init_adam_state(&as->layers[i].time_shift_w4, n_embd, 1); init_adam_state(&as->layers[i].time_mix_r, n_embd, 1); init_adam_state(&as->layers[i].time_mix_k, n_embd, 1); init_adam_state(&as->layers[i].time_mix_v, n_embd, 1); init_adam_state(&as->layers[i].decay_lora_a, n_embd, lora_rank); init_adam_state(&as->layers[i].decay_lora_b, lora_rank, n_embd); init_adam_state(&as->layers[i].decay_base, n_embd, 1); init_adam_state(&as->layers[i].time_first, n_embd, 1); init_adam_state(&as->layers[i].Wr, n_embd, n_embd); init_adam_state(&as->layers[i].Wk, n_embd, n_embd); init_adam_state(&as->layers[i].Wv, n_embd, n_embd); init_adam_state(&as->layers[i].Wo, n_embd, n_embd); init_adam_state(&as->layers[i].channel_mix, n_embd, 1); init_adam_state(&as->layers[i].ffn_gate, n_embd, ffn_h); init_adam_state(&as->layers[i].ffn_up, n_embd, ffn_h); init_adam_state(&as->layers[i].ffn_down, ffn_h, n_embd); init_adam_state(&as->layers[i].mem_gate_write, n_embd, cfg->n_mem_slots); init_adam_state(&as->layers[i].mem_gate_read, n_embd, cfg->n_mem_slots); } init_adam_state(&as->ln_out_weight, n_embd, 1); init_adam_state(&as->ln_out_bias, n_embd, 1); init_adam_state(&as->head, n_embd, vocab_size); } static void free_adam_states(AdamStates *as) { free_adam_state(&as->emb); free_adam_state(&as->ln0_weight); free_adam_state(&as->ln0_bias); for (int i = 
0; i < as->n_layers; i++) { free_adam_state(&as->layers[i].ln1_weight); free_adam_state(&as->layers[i].ln1_bias); free_adam_state(&as->layers[i].ln2_weight); free_adam_state(&as->layers[i].ln2_bias); free_adam_state(&as->layers[i].time_shift_w1); free_adam_state(&as->layers[i].time_shift_w2); free_adam_state(&as->layers[i].time_shift_w4); free_adam_state(&as->layers[i].time_mix_r); free_adam_state(&as->layers[i].time_mix_k); free_adam_state(&as->layers[i].time_mix_v); free_adam_state(&as->layers[i].decay_lora_a); free_adam_state(&as->layers[i].decay_lora_b); free_adam_state(&as->layers[i].decay_base); free_adam_state(&as->layers[i].time_first); free_adam_state(&as->layers[i].Wr); free_adam_state(&as->layers[i].Wk); free_adam_state(&as->layers[i].Wv); free_adam_state(&as->layers[i].Wo); free_adam_state(&as->layers[i].channel_mix); free_adam_state(&as->layers[i].ffn_gate); free_adam_state(&as->layers[i].ffn_up); free_adam_state(&as->layers[i].ffn_down); free_adam_state(&as->layers[i].mem_gate_write); free_adam_state(&as->layers[i].mem_gate_read); } free(as->layers); as->layers = NULL; free_adam_state(&as->ln_out_weight); free_adam_state(&as->ln_out_bias); free_adam_state(&as->head); } /* ============================================================ * Adam Optimizer Update Step * ============================================================ */ static void adam_update(Tensor *param, Tensor *grad, AdamState *state, AdamConfig *config, float lr) { float beta1 = config->beta1; float beta2 = config->beta2; float eps = config->epsilon; float wd = config->weight_decay; int t = config->t; /* Bias correction factors */ float bias_correction1 = 1.0f - powf(beta1, (float)t); float bias_correction2 = 1.0f - powf(beta2, (float)t); for (int i = 0; i < param->size; i++) { float g = grad->data[i]; /* Gradient clipping */ // if (g > GRAD_CLIP) g = GRAD_CLIP; // if (g < -GRAD_CLIP) g = -GRAD_CLIP; /* Update biased first moment estimate */ state->m.data[i] = beta1 * state->m.data[i] + 
(1.0f - beta1) * g; /* Update biased second raw moment estimate */ state->v.data[i] = beta2 * state->v.data[i] + (1.0f - beta2) * g * g; /* Compute bias-corrected estimates */ float m_hat = state->m.data[i] / bias_correction1; float v_hat = state->v.data[i] / bias_correction2; /* Update parameters with AdamW weight decay */ param->data[i] -= lr * (m_hat / (sqrtf(v_hat) + eps) + wd * param->data[i]); } } /* ============================================================ * Gradient Norm Clipping * ============================================================ */ static float compute_tensor_norm_sq(const Tensor *t) { float sum = 0.0f; for (int i = 0; i < t->size; i++) { sum += t->data[i] * t->data[i]; } return sum; } static void scale_tensor(Tensor *t, float scale) { for (int i = 0; i < t->size; i++) { t->data[i] *= scale; } } static void clip_gradients_by_global_norm(ModelGrads *grads, float max_norm) { /* Compute global L2 norm of all gradients */ float total_norm_sq = 0.0f; total_norm_sq += compute_tensor_norm_sq(&grads->emb); total_norm_sq += compute_tensor_norm_sq(&grads->ln0_weight); total_norm_sq += compute_tensor_norm_sq(&grads->ln0_bias); for (int i = 0; i < grads->n_layers; i++) { LayerGrads *lg = &grads->layers[i]; total_norm_sq += compute_tensor_norm_sq(&lg->ln1_weight); total_norm_sq += compute_tensor_norm_sq(&lg->ln1_bias); total_norm_sq += compute_tensor_norm_sq(&lg->ln2_weight); total_norm_sq += compute_tensor_norm_sq(&lg->ln2_bias); total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w1); total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w2); total_norm_sq += compute_tensor_norm_sq(&lg->time_shift_w4); total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_r); total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_k); total_norm_sq += compute_tensor_norm_sq(&lg->time_mix_v); total_norm_sq += compute_tensor_norm_sq(&lg->decay_lora_a); total_norm_sq += compute_tensor_norm_sq(&lg->decay_lora_b); total_norm_sq += 
compute_tensor_norm_sq(&lg->decay_base); total_norm_sq += compute_tensor_norm_sq(&lg->time_first); total_norm_sq += compute_tensor_norm_sq(&lg->Wr); total_norm_sq += compute_tensor_norm_sq(&lg->Wk); total_norm_sq += compute_tensor_norm_sq(&lg->Wv); total_norm_sq += compute_tensor_norm_sq(&lg->Wo); total_norm_sq += compute_tensor_norm_sq(&lg->channel_mix); total_norm_sq += compute_tensor_norm_sq(&lg->ffn_gate); total_norm_sq += compute_tensor_norm_sq(&lg->ffn_up); total_norm_sq += compute_tensor_norm_sq(&lg->ffn_down); total_norm_sq += compute_tensor_norm_sq(&lg->mem_gate_write); total_norm_sq += compute_tensor_norm_sq(&lg->mem_gate_read); } total_norm_sq += compute_tensor_norm_sq(&grads->ln_out_weight); total_norm_sq += compute_tensor_norm_sq(&grads->ln_out_bias); total_norm_sq += compute_tensor_norm_sq(&grads->head); float total_norm = sqrtf(total_norm_sq); /* Scale gradients if norm exceeds max */ if (total_norm > max_norm) { float scale = max_norm / (total_norm + 1e-8f); scale_tensor(&grads->emb, scale); scale_tensor(&grads->ln0_weight, scale); scale_tensor(&grads->ln0_bias, scale); for (int i = 0; i < grads->n_layers; i++) { LayerGrads *lg = &grads->layers[i]; scale_tensor(&lg->ln1_weight, scale); scale_tensor(&lg->ln1_bias, scale); scale_tensor(&lg->ln2_weight, scale); scale_tensor(&lg->ln2_bias, scale); scale_tensor(&lg->time_shift_w1, scale); scale_tensor(&lg->time_shift_w2, scale); scale_tensor(&lg->time_shift_w4, scale); scale_tensor(&lg->time_mix_r, scale); scale_tensor(&lg->time_mix_k, scale); scale_tensor(&lg->time_mix_v, scale); scale_tensor(&lg->decay_lora_a, scale); scale_tensor(&lg->decay_lora_b, scale); scale_tensor(&lg->decay_base, scale); scale_tensor(&lg->time_first, scale); scale_tensor(&lg->Wr, scale); scale_tensor(&lg->Wk, scale); scale_tensor(&lg->Wv, scale); scale_tensor(&lg->Wo, scale); scale_tensor(&lg->channel_mix, scale); scale_tensor(&lg->ffn_gate, scale); scale_tensor(&lg->ffn_up, scale); scale_tensor(&lg->ffn_down, scale); 
scale_tensor(&lg->mem_gate_write, scale); scale_tensor(&lg->mem_gate_read, scale); } scale_tensor(&grads->ln_out_weight, scale); scale_tensor(&grads->ln_out_bias, scale); scale_tensor(&grads->head, scale); } } /* ============================================================ * Apply Adam Updates to All Parameters * ============================================================ */ static void apply_adam_updates(ModelParams *mp, ModelGrads *grads, AdamStates *adam, AdamConfig *config, float lr) { /* Increment timestep */ config->t++; /* Embedding and initial layer norm */ adam_update(&mp->emb, &grads->emb, &adam->emb, config, lr); adam_update(&mp->ln0_weight, &grads->ln0_weight, &adam->ln0_weight, config, lr); adam_update(&mp->ln0_bias, &grads->ln0_bias, &adam->ln0_bias, config, lr); /* Per-layer parameters */ for (int i = 0; i < mp->n_layers; i++) { LayerParams *lp = &mp->layers[i]; LayerGrads *lg = &grads->layers[i]; /* Layer norms */ adam_update(&lp->ln1_weight, &lg->ln1_weight, &adam->layers[i].ln1_weight, config, lr); adam_update(&lp->ln1_bias, &lg->ln1_bias, &adam->layers[i].ln1_bias, config, lr); adam_update(&lp->ln2_weight, &lg->ln2_weight, &adam->layers[i].ln2_weight, config, lr); adam_update(&lp->ln2_bias, &lg->ln2_bias, &adam->layers[i].ln2_bias, config, lr); /* Multi-scale token shift */ adam_update(&lp->time_shift_w1, &lg->time_shift_w1, &adam->layers[i].time_shift_w1, config, lr); adam_update(&lp->time_shift_w2, &lg->time_shift_w2, &adam->layers[i].time_shift_w2, config, lr); adam_update(&lp->time_shift_w4, &lg->time_shift_w4, &adam->layers[i].time_shift_w4, config, lr); /* Token mixing ratios */ adam_update(&lp->time_mix_r, &lg->time_mix_r, &adam->layers[i].time_mix_r, config, lr); adam_update(&lp->time_mix_k, &lg->time_mix_k, &adam->layers[i].time_mix_k, config, lr); adam_update(&lp->time_mix_v, &lg->time_mix_v, &adam->layers[i].time_mix_v, config, lr); /* Data-dependent decay */ adam_update(&lp->decay_lora_a, &lg->decay_lora_a, 
&adam->layers[i].decay_lora_a, config, lr); adam_update(&lp->decay_lora_b, &lg->decay_lora_b, &adam->layers[i].decay_lora_b, config, lr); adam_update(&lp->decay_base, &lg->decay_base, &adam->layers[i].decay_base, config, lr); adam_update(&lp->time_first, &lg->time_first, &adam->layers[i].time_first, config, lr); /* Projections */ adam_update(&lp->Wr, &lg->Wr, &adam->layers[i].Wr, config, lr); adam_update(&lp->Wk, &lg->Wk, &adam->layers[i].Wk, config, lr); adam_update(&lp->Wv, &lg->Wv, &adam->layers[i].Wv, config, lr); adam_update(&lp->Wo, &lg->Wo, &adam->layers[i].Wo, config, lr); /* Channel mix and FFN */ adam_update(&lp->channel_mix, &lg->channel_mix, &adam->layers[i].channel_mix, config, lr); adam_update(&lp->ffn_gate, &lg->ffn_gate, &adam->layers[i].ffn_gate, config, lr); adam_update(&lp->ffn_up, &lg->ffn_up, &adam->layers[i].ffn_up, config, lr); adam_update(&lp->ffn_down, &lg->ffn_down, &adam->layers[i].ffn_down, config, lr); adam_update(&lp->mem_gate_write, &lg->mem_gate_write, &adam->layers[i].mem_gate_write, config, lr); adam_update(&lp->mem_gate_read, &lg->mem_gate_read, &adam->layers[i].mem_gate_read, config, lr); } /* Output layer norm and head */ adam_update(&mp->ln_out_weight, &grads->ln_out_weight, &adam->ln_out_weight, config, lr); adam_update(&mp->ln_out_bias, &grads->ln_out_bias, &adam->ln_out_bias, config, lr); adam_update(&mp->head, &grads->head, &adam->head, config, lr); } /* ============================================================ * Training Function * ============================================================ */ static void train_model(const char *corpus_path, const char *save_path, int epochs, lrnnConfig *cfg, float lr, TokenizerType tok_type, bool auto_config) { printf("======================================================================\n"); printf(" lrnn-like Model - Training (C Implementation)\n"); printf("======================================================================\n\n"); /* Load corpus */ printf("Loading corpus: %s\n", 
corpus_path); FILE *f = fopen(corpus_path, "rb"); if (!f) { fprintf(stderr, "Error: cannot open corpus file: %s\n", corpus_path); return; } fseek(f, 0, SEEK_END); long file_size = ftell(f); fseek(f, 0, SEEK_SET); if (file_size <= 0) { fprintf(stderr, "Error: empty or invalid corpus file\n"); fclose(f); return; } char *text = (char *)malloc((size_t)file_size + 1); if (!text) { fprintf(stderr, "Error: cannot allocate memory for corpus\n"); fclose(f); return; } size_t read_size = fread(text, 1, (size_t)file_size, f); text[read_size] = '\0'; fclose(f); printf(" Loaded %zu bytes\n", read_size); /* Build tokenizer */ printf("\nBuilding tokenizer...\n"); Tokenizer tok; init_tokenizer(&tok, tok_type); build_tokenizer(&tok, text, read_size, tok_type); int vocab_size = tokenizer_vocab_size(&tok); printf(" Vocabulary size: %d %s\n", vocab_size, tok.type == TOKENIZER_CHAR ? "characters" : "words"); /* Auto-configure model if requested */ if (auto_config) { printf("\nAuto-configuring model for corpus size...\n"); *cfg = config_for_corpus(file_size, tok.type, vocab_size); } else { cfg->vocab_size = vocab_size; } /* Tokenize */ int token_count; int *tokens = tokenizer_encode(&tok, text, read_size, &token_count); free(text); printf(" Token count: %d\n", token_count); if (tok.type == TOKENIZER_WORD) { printf(" Compression ratio: %.2fx\n", (float)read_size / (float)token_count); } /* Initialize model */ printf("\nInitializing model...\n"); printf(" Layers: %d\n", cfg->n_layer); printf(" Embedding dim: %d\n", cfg->n_embd); printf(" FFN hidden: %d\n", ffn_hidden(cfg)); printf(" Context length: %d\n", cfg->ctx_len); printf(" LoRA rank: %d\n", cfg->decay_lora_rank); ModelParams mp; memset(&mp, 0, sizeof(mp)); init_model_params(&mp, cfg); /* Count parameters */ long total_params = 0; total_params += mp.emb.size; total_params += mp.ln0_weight.size + mp.ln0_bias.size; for (int i = 0; i < mp.n_layers; i++) { LayerParams *lp = &mp.layers[i]; total_params += lp->ln1_weight.size + 
lp->ln1_bias.size; total_params += lp->ln2_weight.size + lp->ln2_bias.size; total_params += lp->time_shift_w1.size + lp->time_shift_w2.size + lp->time_shift_w4.size; total_params += lp->time_mix_r.size + lp->time_mix_k.size + lp->time_mix_v.size; total_params += lp->decay_lora_a.size + lp->decay_lora_b.size; total_params += lp->decay_base.size + lp->time_first.size; total_params += lp->Wr.size + lp->Wk.size + lp->Wv.size + lp->Wo.size; total_params += lp->channel_mix.size; total_params += lp->ffn_gate.size + lp->ffn_up.size + lp->ffn_down.size; } total_params += mp.ln_out_weight.size + mp.ln_out_bias.size; total_params += mp.head.size; printf(" Total parameters: %ld (%.2f MB)\n", total_params, (float)total_params * sizeof(float) / (1024.0f * 1024.0f)); printf(" Params per byte: %.2f\n", (float)total_params / (float)read_size); /* Initialize gradients */ ModelGrads grads; memset(&grads, 0, sizeof(grads)); init_model_grads(&grads, &mp, cfg); /* Initialize Adam optimizer */ AdamStates adam; memset(&adam, 0, sizeof(adam)); init_adam_states(&adam, cfg); AdamConfig adam_cfg = { .beta1 = 0.9f, .beta2 = 0.999f, .epsilon = 1e-8f, .weight_decay = 0.0f, .t = 0 }; /* Training configuration */ int batch_size = cfg->ctx_len; if (batch_size > token_count - 1) { batch_size = token_count - 1; } int n_batches = (token_count - 1) / batch_size; if (n_batches < 1) n_batches = 1; /* Allocate forward cache */ ForwardCache cache; memset(&cache, 0, sizeof(cache)); init_forward_cache(&cache, batch_size, cfg); printf("\n======================================================================\n"); printf("Starting training...\n"); printf(" Tokenizer: %s\n", tok.type == TOKENIZER_CHAR ? 
"character" : "word"); printf(" Batch size: %d tokens\n", batch_size); printf(" Batches per epoch: %d\n", n_batches); printf("======================================================================\n\n"); time_t start_time = time(NULL); float best_loss = FLT_MAX; for (int epoch = 0; epoch < epochs; epoch++) { float epoch_loss = 0.0f; int batch_count = 0; /* Shuffle batches (just random offset each epoch) */ int offset = rand() % (batch_size > 10 ? 10 : 1); for (int batch = 0; batch < n_batches; batch++) { int start = offset + batch * batch_size; if (start + batch_size >= token_count) continue; int *batch_tokens = tokens + start; int seq_len = batch_size; /* Zero gradients */ zero_model_grads(&grads); /* Forward pass with caching */ float loss = forward_with_cache(&cache, batch_tokens, seq_len, &mp, cfg); /* Check for NaN/Inf */ if (!isfinite(loss)) { printf("Warning: NaN/Inf loss detected at epoch %d, batch %d. Skipping.\n", epoch + 1, batch); continue; } /* Backward pass */ backward_pass(&grads, &cache, batch_tokens, seq_len, &mp, cfg); /* Global gradient clipping */ clip_gradients_by_global_norm(&grads, 5.0f); /* Adam update */ apply_adam_updates(&mp, &grads, &adam, &adam_cfg, lr); epoch_loss += loss; batch_count++; /* Progress indicator */ if ((batch + 1) % 10 == 0 || batch == n_batches - 1) { printf("\r Epoch %d/%d - Batch %d/%d - Loss: %.4f", epoch + 1, epochs, batch + 1, n_batches, batch_count > 0 ? epoch_loss / (float)batch_count : 0.0f); fflush(stdout); } } /* Compute epoch statistics */ float avg_loss = (batch_count > 0) ? 
epoch_loss / (float)batch_count : 0.0f; float perplexity = expf(avg_loss); time_t elapsed = time(NULL) - start_time; int hours = (int)(elapsed / 3600); int mins = (int)((elapsed % 3600) / 60); int secs = (int)(elapsed % 60); printf("\n Epoch %d/%d complete - Loss: %.4f - Perplexity: %.2f - Time: %02d:%02d:%02d\n", epoch + 1, epochs, avg_loss, perplexity, hours, mins, secs); /* Track best loss */ if (avg_loss < best_loss) { best_loss = avg_loss; printf(" ** New best loss! **\n"); } /* Save checkpoint */ if ((epoch + 1) % 5 == 0 || epoch == epochs - 1) { printf(" Saving checkpoint to: %s\n", save_path); if (save_model(save_path, &mp, cfg, &tok) != 0) { fprintf(stderr, " Warning: failed to save checkpoint\n"); } } } /* Final save */ printf("\nSaving final model to: %s\n", save_path); save_model(save_path, &mp, cfg, &tok); /* Cleanup */ free(tokens); free_forward_cache(&cache); free_model_grads(&grads); free_adam_states(&adam); free_model_params(&mp); free_tokenizer(&tok); } /* ============================================================ * Generation with Dynamic State Checkpointing * ============================================================ */ /* ============================================================ * Generation with Dynamic State Checkpointing * ============================================================ */ /* Comparison function for qsort - descending by probability */ typedef struct { float prob; int index; } ProbIndex; static int prob_index_cmp_desc(const void *a, const void *b) { float pa = ((const ProbIndex *)a)->prob; float pb = ((const ProbIndex *)b)->prob; if (pa > pb) return -1; if (pa < pb) return 1; return 0; } static int sample_top_p(const float *probs, int vocab_size, float top_p) { /* Build (prob, index) pairs */ ProbIndex *pi = (ProbIndex *)malloc((size_t)vocab_size * sizeof(ProbIndex)); if (!pi) { /* Fallback: argmax */ int best = 0; for (int i = 1; i < vocab_size; i++) { if (probs[i] > probs[best]) best = i; } return best; } for (int i = 
0; i < vocab_size; i++) { pi[i].prob = probs[i]; pi[i].index = i; } /* O(V log V) sort instead of O(V²) bubble sort */ qsort(pi, (size_t)vocab_size, sizeof(ProbIndex), prob_index_cmp_desc); /* Find top-p cutoff */ float cumsum = 0.0f; int cutoff = vocab_size; /* default: use all */ for (int i = 0; i < vocab_size; i++) { cumsum += pi[i].prob; if (cumsum >= top_p) { cutoff = i + 1; break; } } if (cutoff < 1) cutoff = 1; /* Renormalize over the kept tokens */ float sum = 0.0f; for (int i = 0; i < cutoff; i++) { sum += pi[i].prob; } /* Sample from the truncated distribution */ float r = ((float)rand() / (float)RAND_MAX) * sum; float running = 0.0f; int sampled = pi[0].index; for (int i = 0; i < cutoff; i++) { running += pi[i].prob; if (running >= r) { sampled = pi[i].index; break; } } free(pi); return sampled; } static void generate_text(const char *model_path, const char *seed_text, int n_tokens, float temperature, float top_p) { printf("======================================================================\n"); printf(" Text Generation\n"); printf("======================================================================\n\n"); /* Load model */ printf("Loading model: %s\n", model_path); ModelParams mp; lrnnConfig cfg; Tokenizer tok; memset(&mp, 0, sizeof(mp)); memset(&tok, 0, sizeof(tok)); if (load_model(model_path, &mp, &cfg, &tok) != 0) { fprintf(stderr, "Failed to load model\n"); return; } printf(" Model loaded!\n"); printf(" Tokenizer: %s\n", tok.type == TOKENIZER_CHAR ? 
"character" : "word"); printf(" Vocab size: %d\n", cfg.vocab_size); printf(" Layers: %d, Dim: %d\n", cfg.n_layer, cfg.n_embd); /* Initialize state */ ModelState state; init_model_state(&state, &cfg); printf("\nSeed: \"%s\"\n", seed_text); printf("Generating %d tokens (temp=%.2f, top_p=%.2f)\n\n", n_tokens, temperature, top_p); printf("======================================================================\n"); printf("%s", seed_text); fflush(stdout); float *logits = (float *)malloc((size_t)cfg.vocab_size * sizeof(float)); float *probs = (float *)malloc((size_t)cfg.vocab_size * sizeof(float)); /* Tokenize and process seed */ int seed_token_count; int *seed_tokens = tokenizer_encode(&tok, seed_text, strlen(seed_text), &seed_token_count); int last_token = 0; for (int i = 0; i < seed_token_count; i++) { forward_single(logits, seed_tokens[i], &mp, &state, &cfg); last_token = seed_tokens[i]; } free(seed_tokens); /* Generate tokens */ srand((unsigned int)time(NULL)); char decode_buf[MAX_WORD_LEN]; for (int i = 0; i < n_tokens; i++) { forward_single(logits, last_token, &mp, &state, &cfg); /* Apply temperature */ if (temperature != 1.0f) { for (int j = 0; j < cfg.vocab_size; j++) { logits[j] /= temperature; } } softmax_vec(probs, logits, cfg.vocab_size); /* Sample */ int next_token = sample_top_p(probs, cfg.vocab_size, top_p); /* Decode and print */ tokenizer_decode_token(&tok, next_token, decode_buf, sizeof(decode_buf)); printf("%s", decode_buf); fflush(stdout); last_token = next_token; } printf("\n======================================================================\n"); /* Cleanup */ free(logits); free(probs); free_model_state(&state); free_model_params(&mp); free_tokenizer(&tok); } /* ============================================================ * Main Entry Point * ============================================================ */ static void print_usage(const char *prog) { printf("Usage:\n"); printf(" Training:\n"); printf(" %s --train corpus.txt --save model.bin 
[options]\n\n", prog); printf(" Generation:\n"); printf(" %s --load model.bin --seed \"text\" [options]\n\n", prog); printf("Options:\n"); printf(" --train FILE Path to training corpus\n"); printf(" --save FILE Path to save model\n"); printf(" --load FILE Path to load model\n"); printf(" --seed TEXT Seed text for generation\n"); printf(" --epochs N Training epochs (default: 20)\n"); printf(" --tokens N Tokens to generate (default: 200)\n"); printf(" --layers N Number of layers (default: auto)\n"); printf(" --dim N Embedding dimension (default: auto)\n"); printf(" --heads N Number of heads (default: auto)\n"); printf(" --ctx N Max context length (default: auto)\n"); printf(" --lr FLOAT Learning rate (default: 0.0003)\n"); printf(" --temp FLOAT Temperature (default: 0.8)\n"); printf(" --top_p FLOAT Top-p sampling (default: 0.9)\n"); printf(" --tokenizer TYPE char, word, or auto (default: auto)\n"); printf(" --auto-config Auto-configure model size (default: on)\n"); printf(" --no-auto-config Disable auto-configuration\n"); printf(" --help Show this help\n"); } int main(int argc, char *argv[]) { static struct option long_options[] = { {"train", required_argument, 0, 't'}, {"save", required_argument, 0, 's'}, {"load", required_argument, 0, 'l'}, {"seed", required_argument, 0, 'S'}, {"epochs", required_argument, 0, 'e'}, {"tokens", required_argument, 0, 'n'}, {"layers", required_argument, 0, 'L'}, {"dim", required_argument, 0, 'd'}, {"heads", required_argument, 0, 'h'}, {"ctx", required_argument, 0, 'c'}, {"lr", required_argument, 0, 'r'}, {"temp", required_argument, 0, 'T'}, {"top_p", required_argument, 0, 'p'}, {"tokenizer", required_argument, 0, 'k'}, {"auto-config", no_argument, 0, 'A'}, {"no-auto-config", no_argument, 0, 'N'}, {"help", no_argument, 0, 'H'}, {0, 0, 0, 0} }; char *train_path = NULL; char *save_path = NULL; char *load_path = NULL; char *seed_text = NULL; int epochs = 20; int n_tokens = 200; float lr = 0.0003f; float temperature = 0.8f; float top_p = 
0.9f; TokenizerType tok_type = TOKENIZER_AUTO; bool auto_config = true; bool manual_config = false; lrnnConfig cfg = default_config(); int opt; int option_index = 0; while ((opt = getopt_long(argc, argv, "t:s:l:S:e:n:L:d:h:c:r:T:p:k:ANH", long_options, &option_index)) != -1) { switch (opt) { case 't': train_path = optarg; break; case 's': save_path = optarg; break; case 'l': load_path = optarg; break; case 'S': seed_text = optarg; break; case 'e': epochs = atoi(optarg); break; case 'n': n_tokens = atoi(optarg); break; case 'L': cfg.n_layer = atoi(optarg); manual_config = true; break; case 'd': cfg.n_embd = atoi(optarg); manual_config = true; break; case 'h': cfg.n_head = atoi(optarg); manual_config = true; break; case 'c': cfg.ctx_len = atoi(optarg); manual_config = true; break; case 'r': lr = (float)atof(optarg); break; case 'T': temperature = (float)atof(optarg); break; case 'p': top_p = (float)atof(optarg); break; case 'k': if (strcmp(optarg, "char") == 0) { tok_type = TOKENIZER_CHAR; } else if (strcmp(optarg, "word") == 0) { tok_type = TOKENIZER_WORD; } else { tok_type = TOKENIZER_AUTO; } break; case 'A': auto_config = true; break; case 'N': auto_config = false; break; case 'H': default: print_usage(argv[0]); return (opt == 'H') ? 0 : 1; } } /* If user specified any model params manually, disable auto-config */ if (manual_config) { auto_config = false; } /* Training mode */ if (train_path) { if (!save_path) { fprintf(stderr, "Error: --save is required when training\n"); return 1; } train_model(train_path, save_path, epochs, &cfg, lr, tok_type, auto_config); } /* Generation mode */ else if (load_path) { if (!seed_text || strlen(seed_text) == 0) { fprintf(stderr, "Error: --seed is required for generation\n"); return 1; } generate_text(load_path, seed_text, n_tokens, temperature, top_p); } else { print_usage(argv[0]); return 1; } return 0; }