import Accelerate
import CPUOps
import Darwin
import Foundation
import ANERuntime
import ANETypes
import Espresso

@main
enum EspressoTrainMain {
    private static let defaultCheckpointPath = "ane_stories110M_ckpt.bin"
    private static let defaultModelPath = "../../assets/models/stories110M.bin"
    private static let defaultDataPath = "tinystories_data00.bin"

    // MARK: - Arguments

    /// Command-line options; each field is set by the flag named in its comment (see `parseArgs`).
    private struct Args {
        var resume: Bool = false                            // --resume
        var totalSteps: Int = 10_000                        // --steps
        var lr: Float = 3e-4                                // --lr
        var checkpointPath: String = defaultCheckpointPath  // --ckpt
        var modelPath: String = defaultModelPath            // --model
        var dataPath: String = defaultDataPath              // --data
        var twoStepStudentSidecarPath: String?              // --export-two-step-student
        var generationModelExportPath: String?              // --export-generation-model
        var localTextDatasetPath: String?                   // --build-local-text-dataset (train() builds the dataset and exits)
        var localBigramPrefix: String?                      // --export-local-bigram-prefix
        var artifactLayerCount: Int = 6                     // --artifact-layer-count
        var offlineAcceptanceJSONPath: String?              // --offline-acceptance-json
        var offlineRecurrentCheckpointPath: String?         // --offline-recurrent-checkpoint
        var offlineFutureSidecarPath: String?               // --offline-future-sidecar
        var promptToken: TokenID?                           // --prompt-token
        var gateMaxNewTokens: Int = 8                       // --gate-max-new-tokens
        var textRoots: [String] = []                        // --text-root (repeatable)
        var textExtensions: Set<String> = ["swift", "md", "txt", "py", "sh", "m", "h", "c"] // --text-ext adds to this set
        var maxCorpusFiles: Int?                            // --max-corpus-files
        var maxCorpusBytes: Int?                            // --max-corpus-bytes
    }

    /// How `train` ended: normally, or after checkpointing because the compile budget ran out.
    private enum TrainExit {
        case finished
        /// Handed to `ExecRestart.restart` by `main()` to continue in a fresh process.
        case execRestart(step: Int, compileCount: Int, loss: Float)
    }

    /// Entry point: parses the command line and runs `train`. On `.execRestart` it delegates to
    /// `ExecRestart.restart`; any thrown error is printed to stderr and the process exits with status 1.
    static func main() {
        do {
            let exitReason = try train(args: parseArgs(CommandLine.arguments))
            switch exitReason {
            case .finished:
                return
            case let .execRestart(step, compileCount, loss):
                ExecRestart.restart(step: step, compileCount: compileCount, loss: loss)
            }
        } catch {
            fputs("espresso-train error: \(error)\n", stderr)
            exit(1)
        }
    }

    /// Parses `argv` (program name at index 0) into `Args`.
    ///
    /// Parsing is lenient and never aborts the run. Problems are reported on stderr and otherwise
    /// ignored, so the affected field keeps its previous value. Reported problems are:
    /// - unrecognized flags
    /// - value-taking flags given as the last argument (no value follows)
    /// - values that fail to parse
    ///
    /// Every value-taking flag consumes exactly one following argument. `--text-root` may be
    /// repeated. `--text-ext` adds a trimmed, lowercased extension to the default set rather
    /// than replacing it.
    private static func parseArgs(_ argv: [String]) -> Args {
        var a = Args()
        var i = 1

        // Report a recoverable argument problem without aborting.
        func warn(_ message: String) {
            fputs("espresso-train warning: \(message)\n", stderr)
        }

        // Consume the argument after the flag at `i`; nil (with a warning) when the flag is last.
        func takeValue(for flag: String) -> String? {
            guard i + 1 < argv.count else {
                warn("missing value for \(flag); ignoring")
                return nil
            }
            i += 1
            return argv[i]
        }

        // Consume and parse the argument after `flag`; nil (with a warning) when absent or malformed.
        func takeParsed<T>(for flag: String, _ parse: (String) -> T?) -> T? {
            guard let raw = takeValue(for: flag) else { return nil }
            guard let value = parse(raw) else {
                warn("invalid value '\(raw)' for \(flag); ignoring")
                return nil
            }
            return value
        }

        while i < argv.count {
            let arg = argv[i]
            switch arg {
            case "--resume":
                a.resume = true
            case "--steps":
                a.totalSteps = takeParsed(for: arg, { Int($0) }) ?? a.totalSteps
            case "--lr":
                a.lr = takeParsed(for: arg, { Float($0) }) ?? a.lr
            case "--ckpt":
                a.checkpointPath = takeValue(for: arg) ?? a.checkpointPath
            case "--model":
                a.modelPath = takeValue(for: arg) ?? a.modelPath
            case "--data":
                a.dataPath = takeValue(for: arg) ?? a.dataPath
            case "--export-two-step-student":
                a.twoStepStudentSidecarPath = takeValue(for: arg) ?? a.twoStepStudentSidecarPath
            case "--export-generation-model":
                a.generationModelExportPath = takeValue(for: arg) ?? a.generationModelExportPath
            case "--build-local-text-dataset":
                a.localTextDatasetPath = takeValue(for: arg) ?? a.localTextDatasetPath
            case "--export-local-bigram-prefix":
                a.localBigramPrefix = takeValue(for: arg) ?? a.localBigramPrefix
            case "--artifact-layer-count":
                a.artifactLayerCount = takeParsed(for: arg, { Int($0) }) ?? a.artifactLayerCount
            case "--offline-acceptance-json":
                a.offlineAcceptanceJSONPath = takeValue(for: arg) ?? a.offlineAcceptanceJSONPath
            case "--offline-recurrent-checkpoint":
                a.offlineRecurrentCheckpointPath = takeValue(for: arg) ?? a.offlineRecurrentCheckpointPath
            case "--offline-future-sidecar":
                a.offlineFutureSidecarPath = takeValue(for: arg) ?? a.offlineFutureSidecarPath
            case "--prompt-token":
                if let token = takeParsed(for: arg, { TokenID($0) }) { a.promptToken = token }
            case "--gate-max-new-tokens":
                a.gateMaxNewTokens = takeParsed(for: arg, { Int($0) }) ?? a.gateMaxNewTokens
            case "--text-root":
                if let root = takeValue(for: arg) { a.textRoots.append(root) }
            case "--text-ext":
                if let raw = takeValue(for: arg) {
                    let ext = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
                    if !ext.isEmpty { a.textExtensions.insert(ext) }
                }
            case "--max-corpus-files":
                // A malformed value no longer clears an earlier valid limit.
                if let n = takeParsed(for: arg, { Int($0) }) { a.maxCorpusFiles = n }
            case "--max-corpus-bytes":
                if let n = takeParsed(for: arg, { Int($0) }) { a.maxCorpusBytes = n }
            default:
                // Previously dropped silently; surface typos such as `--step` instead of `--steps`.
                warn("ignoring unrecognized argument '\(arg)'")
            }
            i += 1
        }
        return a
    }

    // MARK: - Timing

    /// Converts `mach_absolute_time()` tick deltas to milliseconds using the cached host timebase.
    private enum MachTime {
        private static let tb: mach_timebase_info_data_t = {
            var tb = mach_timebase_info_data_t()
            mach_timebase_info(&tb)
            return tb
        }()

        /// Current value of `mach_absolute_time()` in timebase ticks.
        @inline(__always)
        static func now() -> UInt64 { mach_absolute_time() }

        /// Milliseconds spanned by `delta` ticks (ticks * numer / denom = nanoseconds).
        @inline(__always)
        static func ms(_ delta: UInt64) -> Double {
            let nanos = (Double(delta) * Double(tb.numer)) / Double(tb.denom)
            return nanos / 1_000_000.0
        }
    }

    // MARK: - Training

    /// Runs the mode selected by `args`. Returns `.finished` when done, or `.execRestart` after
    /// checkpointing when the compile budget cannot fit another batch of weight-bearing kernels.
    private static func train(args: Args) throws -> TrainExit {
        setbuf(stdout, nil)
        CheckpointHeader.validateLayout()

        let dim = ModelConfig.dim
        let hidden = ModelConfig.hidden
        let heads = ModelConfig.heads
        let seqLen = ModelConfig.seqLen
        let nLayers = ModelConfig.nLayers
        let vocab = ModelConfig.vocab

        let beta1: Float = 0.9
        let beta2: Float = 0.999
        let eps: Float = 1e-8

        let posix = Locale(identifier: "en_US_POSIX")

        if let datasetPath = args.localTextDatasetPath {
            let roots = args.textRoots.isEmpty ?
[FileManager.default.currentDirectoryPath] : args.textRoots let tokens = try LocalTextTokenDatasetBuilder.collectTokens( roots: roots, allowedExtensions: args.textExtensions, maxFiles: args.maxCorpusFiles, maxBytes: args.maxCorpusBytes ) try LocalTextTokenDatasetBuilder.writeUInt16Dataset(tokens: tokens, to: datasetPath) print("[built local text dataset: \(datasetPath) tokens=\(tokens.count) roots=\(roots.count)]") return .finished } var offlineRecurrentCheckpointPath = args.offlineRecurrentCheckpointPath var offlineFutureSidecarPath = args.offlineFutureSidecarPath var offlinePromptToken = args.promptToken if let prefix = args.localBigramPrefix { let manifest = try LocalRealArtifactPipeline.exportLocalBigramArtifacts( datasetPath: args.dataPath, prefix: prefix, layerCount: args.artifactLayerCount, vocabSize: vocab ) offlineRecurrentCheckpointPath = offlineRecurrentCheckpointPath ?? manifest.recurrentCheckpointPath offlineFutureSidecarPath = offlineFutureSidecarPath ?? manifest.futureSidecarPath offlinePromptToken = offlinePromptToken ?? 
manifest.promptToken print( "[exported local bigram artifacts: manifest=\(manifest.manifestPath) prompt=\(manifest.promptToken) tokens=\(manifest.tokenCount)]" ) if args.offlineAcceptanceJSONPath == nil { return .finished } } if let offlineAcceptanceJSONPath = args.offlineAcceptanceJSONPath { guard let recurrentCheckpointPath = offlineRecurrentCheckpointPath else { throw GenerationError.invalidArguments("offline gate requires --offline-recurrent-checkpoint or --export-local-bigram-prefix") } guard let futureSidecarPath = offlineFutureSidecarPath else { throw GenerationError.invalidArguments("offline gate requires --offline-future-sidecar or --export-local-bigram-prefix") } guard let promptToken = offlinePromptToken else { throw GenerationError.invalidArguments("offline gate requires --prompt-token or an exported local-bigram manifest") } let trace = try LocalRealArtifactPipeline.offlineAcceptanceGate( recurrentCheckpointPath: recurrentCheckpointPath, futureSidecarPath: futureSidecarPath, promptTokens: [promptToken], maxNewTokens: args.gateMaxNewTokens ) let payload: [String: Any] = [ "prompt_tokens": [Int(promptToken)], "generated_tokens": trace.generatedTokens.map(Int.init), "committed_exact_token_counts": trace.committedExactTokenCounts, "accepted_future_token_counts": trace.acceptedFutureTokenCounts, "parity_status": trace.parityMatchedAllCommittedTokens ? "match" : "mismatch", "committed_exact_tokens_per_pass": trace.committedExactTokensPerPass, "accepted_future_tokens_per_pass": trace.acceptedFutureTokensPerPass, "recurrent_checkpoint": recurrentCheckpointPath, "future_sidecar": futureSidecarPath, ] let data = try JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted, .sortedKeys]) try data.write(to: URL(fileURLWithPath: offlineAcceptanceJSONPath), options: .atomic) print("[wrote offline acceptance gate: \(offlineAcceptanceJSONPath)]") return .finished } // Allocate per-layer state. 
let layers = LayerStorage(count: nLayers) { _ in LayerWeights() } let layerAdam = LayerStorage(count: nLayers) { _ in LayerAdam() } let acts = LayerStorage(count: nLayers) { _ in LayerActivations() } let grads = LayerStorage(count: nLayers) { _ in LayerGradients() } // Globals: final RMSNorm + embedding. let rmsFinal = TensorBuffer(count: dim, zeroed: false) let embed = TensorBuffer(count: vocab * dim, zeroed: false) // [vocab, dim] row-major let grmsFinal = TensorBuffer(count: dim, zeroed: true) let gembed = TensorBuffer(count: vocab * dim, zeroed: true) let adamRmsFinal = AdamState(count: dim) let adamEmbed = AdamState(count: vocab * dim) var totalSteps = args.totalSteps var lr = args.lr var cumCompile: Double = 0 var cumTrain: Double = 0 var cumWall: Double = 0 var cumSteps: Int = 0 var cumBatches: Int = 0 var adamT: Int = 0 var startStep: Int = 0 var lastLoss: Float = 999.0 var resuming: Bool = false if args.resume { do { let meta = try Checkpoint.load( path: args.checkpointPath, intoLayers: layers, intoLayerAdam: layerAdam, intoRmsFinal: rmsFinal, intoAdamRmsFinal: adamRmsFinal, intoEmbed: embed, intoAdamEmbed: adamEmbed ) startStep = meta.step totalSteps = meta.totalSteps lr = meta.lr lastLoss = meta.loss cumCompile = meta.cumCompile cumTrain = meta.cumTrain cumWall = meta.cumWall cumSteps = meta.cumSteps cumBatches = meta.cumBatches adamT = meta.adamT resuming = true print(String(format: "[RESUMED step %d, loss=%.4f]", startStep, lastLoss)) } catch { print("[resume failed: \(error)]") } } if !resuming { // Load pretrained weights. The checkpoint format assumes a shared embed/classifier. do { let pretrained = try ModelWeightLoader.load(from: args.modelPath) precondition(pretrained.sharedClassifier, "Checkpoint format assumes shared embed/classifier weights") for L in 0.. Float { // Match ObjC ordering/precision: (float scale) promoted to double, multiplied in double, // then cast once on assignment to Float. 
Float(scale * (2.0 * drand48() - 1.0)) } for L in 0..(count: nLayers, throwingInitializer: { _ in try StaticKernel() }) let accumulator = GradientAccumulator() Sampler.seed(startStep: startStep) // Scratch buffers (allocate once, reuse). let xCur = TensorBuffer(count: dim * seqLen, zeroed: false) let xFinal = TensorBuffer(count: dim * seqLen, zeroed: false) let logits = TensorBuffer(count: vocab * seqLen, zeroed: false) let dlogits = TensorBuffer(count: vocab * seqLen, zeroed: false) let dy = TensorBuffer(count: dim * seqLen, zeroed: false) let bwdScratch = BackwardScratch(dim: dim, hidden: hidden, seqLen: seqLen) let rmsWorkspace = RMSNorm.Workspace(seqLen: seqLen) let crossEntropyWorkspace = CrossEntropy.Workspace(vocabSize: vocab, seqLen: seqLen) // Token ID conversion buffers: dataset is on-disk UInt16; CPUOps expect TokenID (UInt32). var inputTokenIDs = [TokenID](repeating: 0, count: seqLen) var targetTokenIDs = [TokenID](repeating: 0, count: seqLen) var totalCompileMs: Double = 0 var totalTrainMs: Double = 0 var stepsDoneThisRun: Int = 0 var batchesThisRun: Int = 0 let wallStart = MachTime.now() @inline(__always) func stderrLine(_ line: String) { line.withCString { cstr in _ = fputs(cstr, stderr) _ = fputc(0x0A, stderr) } } @inline(__always) func adamUpdate( weights w: borrowing TensorBuffer, grads g: borrowing TensorBuffer, state s: borrowing AdamState, timestep: Int ) { w.withUnsafeMutablePointer { wPtr in g.withUnsafePointer { gPtr in s.m.withUnsafeMutablePointer { mPtr in s.v.withUnsafeMutablePointer { vPtr in AdamOptimizer.update( weights: wPtr, gradients: gPtr, m: mPtr, v: vPtr, count: s.count, timestep: timestep, lr: lr, beta1: beta1, beta2: beta2, eps: eps ) } } } } } var step = startStep while step < totalSteps { // Compile budget: exec restart if we can't compile another full weight-bearing batch. 
if CompileBudget.currentCount + ModelConfig.totalWeightKernels > ModelConfig.maxCompiles { let wallMs = MachTime.ms(MachTime.now() - wallStart) var meta = CheckpointMeta() meta.step = step meta.totalSteps = totalSteps meta.lr = lr meta.loss = lastLoss meta.cumCompile = cumCompile + totalCompileMs meta.cumTrain = cumTrain + totalTrainMs meta.cumWall = cumWall + wallMs meta.cumSteps = cumSteps + stepsDoneThisRun meta.cumBatches = cumBatches + batchesThisRun meta.adamT = adamT try Checkpoint.save( path: args.checkpointPath, meta: meta, layers: layers, layerAdam: layerAdam, rmsFinal: rmsFinal, adamRmsFinal: adamRmsFinal, embed: embed, adamEmbed: adamEmbed ) return .execRestart(step: step, compileCount: CompileBudget.currentCount, loss: lastLoss) } // Compile all layers' weight-bearing kernels. let tc0 = MachTime.now() let kernelStorage: LayerStorage do { kernelStorage = try LayerStorage(count: nLayers, throwingInitializer: { i in try LayerKernelSet(weights: layers[i]) }) } catch { // Compile failed: force budget exhaustion and restart on next iteration. try? CompileBudget.setCount(ModelConfig.maxCompiles) continue } let cms = MachTime.ms(MachTime.now() - tc0) totalCompileMs += cms let surfaceHandles: [LayerSurfaceHandles] do { surfaceHandles = try SurfaceHandleCache.build(kernels: kernelStorage, staticKernels: staticKernels) } catch { try? CompileBudget.setCount(ModelConfig.maxCompiles) continue } // Zero gradient accumulators (accumulate across micro-steps). for L in 0.. xCur. xCur.withUnsafeMutablePointer { xPtr in embed.withUnsafePointer { ePtr in inputTokenIDs.withUnsafeBufferPointer { tokensBuf in Embedding.lookup( output: xPtr, embedding: ePtr, tokens: tokensBuf.baseAddress!, vocabSize: vocab, dim: dim, seqLen: seqLen ) } } } stepTimings.tElem += MachTime.ms(MachTime.now() - t0) // Wait for prior async dW work before touching layer IO for this step. 
t0 = MachTime.now() accumulator.barrier() stepTimings.tCblasWait += MachTime.ms(MachTime.now() - t0) // Forward pass (12 layers). try ForwardPass.runTimed( xCur: xCur, acts: acts, kernels: kernelStorage, accumulator: accumulator, dim: dim, hidden: hidden, seqLen: seqLen, surfaceHandles: surfaceHandles, timings: &stepTimings ) // Final RMSNorm. t0 = MachTime.now() xFinal.withUnsafeMutablePointer { outPtr in xCur.withUnsafePointer { inPtr in rmsFinal.withUnsafePointer { wPtr in RMSNorm.forward( output: outPtr, input: inPtr, weights: wPtr, dim: dim, seqLen: seqLen, workspace: rmsWorkspace ) } } } stepTimings.tRms += MachTime.ms(MachTime.now() - t0) // Classifier: logits = embed @ xFinal. t0 = MachTime.now() logits.withUnsafeMutablePointer { logitsPtr in embed.withUnsafePointer { ePtr in xFinal.withUnsafePointer { xPtr in BLAS.sgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans, m: Int32(vocab), n: Int32(seqLen), k: Int32(dim), alpha: 1.0, a: ePtr, lda: Int32(dim), b: xPtr, ldb: Int32(seqLen), beta: 0.0, c: logitsPtr, ldc: Int32(seqLen) ) } } } stepTimings.tCls += MachTime.ms(MachTime.now() - t0) // Cross-entropy loss + dlogits. t0 = MachTime.now() let loss = dlogits.withUnsafeMutablePointer { dlogitsPtr in logits.withUnsafePointer { logitsPtr in targetTokenIDs.withUnsafeBufferPointer { targetsBuf in CrossEntropy.lossAndGradient( dlogits: dlogitsPtr, logits: logitsPtr, targets: targetsBuf.baseAddress!, vocabSize: vocab, seqLen: seqLen, workspace: crossEntropyWorkspace ) } } } lastLoss = loss stepTimings.tElem += MachTime.ms(MachTime.now() - t0) // Classifier backward: dy = embed^T @ dlogits. 
dy.withUnsafeMutablePointer { dyPtr in embed.withUnsafePointer { ePtr in dlogits.withUnsafePointer { dlogitsPtr in BLAS.sgemm( CblasRowMajor, CblasTrans, CblasNoTrans, m: Int32(dim), n: Int32(seqLen), k: Int32(vocab), alpha: 1.0, a: ePtr, lda: Int32(dim), b: dlogitsPtr, ldb: Int32(seqLen), beta: 0.0, c: dyPtr, ldc: Int32(seqLen) ) } } } // dembed += dlogits @ xFinal^T (async, accumulate). let gembedPtr = gembed.withUnsafeMutablePointer { SendablePointer($0) } let captDlogits = dlogits.withUnsafePointer { SendableConstPointer($0) } let captXFinal = xFinal.withUnsafePointer { SendableConstPointer($0) } accumulator.enqueue { [captDlogits, captXFinal] in BLAS.sgemm( CblasRowMajor, CblasNoTrans, CblasTrans, m: Int32(vocab), n: Int32(dim), k: Int32(seqLen), alpha: 1.0, a: captDlogits.pointer, lda: Int32(seqLen), b: captXFinal.pointer, ldb: Int32(seqLen), beta: 1.0, c: gembedPtr.pointer, ldc: Int32(dim) ) } // Final RMSNorm backward: dx_rms_final -> dy (in-place overwrite). bwdScratch.dxRms1.withUnsafeMutablePointer { dxPtr in grmsFinal.withUnsafeMutablePointer { dwPtr in dy.withUnsafePointer { dyPtr in xCur.withUnsafePointer { xPtr in rmsFinal.withUnsafePointer { wPtr in RMSNorm.backward( dx: dxPtr, dw: dwPtr, dy: dyPtr, x: xPtr, weights: wPtr, dim: dim, seqLen: seqLen, workspace: rmsWorkspace ) } } } } } dy.withUnsafeMutablePointer { dyPtr in bwdScratch.dxRms1.withUnsafePointer { dxPtr in dyPtr.update(from: dxPtr, count: dim * seqLen) } } // Backward pass (12 layers, reverse). try BackwardPass.runTimed( dy: dy, acts: acts, kernels: kernelStorage, staticKernels: staticKernels, grads: grads, weights: layers, scratch: bwdScratch, rmsWorkspace: rmsWorkspace, accumulator: accumulator, dim: dim, hidden: hidden, seqLen: seqLen, heads: heads, surfaceHandles: surfaceHandles, timings: &stepTimings ) // Embedding backward (accumulates into gembed). 
accumulator.barrier() gembed.withUnsafeMutablePointer { gPtr in dy.withUnsafePointer { dyPtr in inputTokenIDs.withUnsafeBufferPointer { tokensBuf in Embedding.backward( dEmbedding: gPtr, dx: dyPtr, tokens: tokensBuf.baseAddress!, vocabSize: vocab, dim: dim, seqLen: seqLen ) } } } if currentStep % 10 == 0 || currentStep == startStep { print(String(format: "step %-4d loss=%.4f", locale: posix, currentStep, lastLoss)) } let stepMs = MachTime.ms(MachTime.now() - stepT0) batchStepMsAccum += stepMs let completedSteps = stepsBatch + 1 batchTimings.tAne += stepTimings.tAne batchTimings.tIO += stepTimings.tIO batchTimings.tCls += stepTimings.tCls batchTimings.tElem += stepTimings.tElem batchTimings.tRms += stepTimings.tRms batchTimings.tCblasWait += stepTimings.tCblasWait stderrLine( String( format: "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,\"ms\":%.3f,\"ms_per_step\":%.3f,\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,\"compiles\":%d}", locale: posix, currentStep, lastLoss, stepMs, batchStepMsAccum / Double(completedSteps), batchTimings.tAne / Double(completedSteps), batchTimings.tIO / Double(completedSteps), batchTimings.tCls / Double(completedSteps), batchTimings.tElem / Double(completedSteps), batchTimings.tRms / Double(completedSteps), batchTimings.tCblasWait / Double(completedSteps), CompileBudget.currentCount ) ) stepsBatch += 1 stepsDoneThisRun += 1 step += 1 } let tms = MachTime.ms(MachTime.now() - tt0) totalTrainMs += tms batchesThisRun += 1 // Wait all async dW. accumulator.waitAll() // Scale gradients by mean over actual accumulated steps. let gsc = 1.0 / Float(stepsBatch) for L in 0..