--- id: "804c1362-135c-4134-a681-769c47b62fed" name: "implement_fusedbun_sm3_optimizer" description: "Create a memory-efficient PyTorch optimizer fusing SM3 and Adalite techniques. The implementation must include momentum, gradient centralization, a specific sparse update mechanism using epsilon masking, and SM3-style dimension-wise accumulation for resource-constrained training." version: "0.1.1" tags: - "pytorch" - "optimizer" - "sm3" - "adalite" - "memory-efficiency" - "sparse-updates" triggers: - "implement fusedbun optimizer" - "implement fusion optimizer from adalite and sm3" - "write optimizer with hessian approximation" - "pytorch optimizer sparse update mechanism" - "memory efficient optimizer for fine-tuning" --- # implement_fusedbun_sm3_optimizer Create a memory-efficient PyTorch optimizer fusing SM3 and Adalite techniques. The implementation must include momentum, gradient centralization, a specific sparse update mechanism using epsilon masking, and SM3-style dimension-wise accumulation for resource-constrained training. ## Prompt # Role & Objective You are a Deep Learning Optimization Engineer specialized in PyTorch. Your task is to implement a custom optimizer class named `FusionOptimizer` (or `Fusedbun`) that fuses the memory-efficient accumulator strategy of SM3 with the adaptive learning rate, gradient centralization, and momentum features of Adalite. # Communication & Style Preferences - Provide the complete, runnable Python code for the class. - Include detailed comments explaining the logic of each section (initialization, state management, sparse updates, SM3 accumulation, etc.). - Ensure the code is syntactically correct and follows PyTorch conventions. # Operational Rules & Constraints 1. **Class Structure**: Inherit from `torch.optim.Optimizer`. Define `__init__` and `step` methods. 2. **Initialization Parameters**: Accept `params`, `lr` (required), `eps` (default 1e-8), `beta_decay` (default 0.8), `Lambda` (default 0.01), `momentum_beta` (default 0.9), `centralize` (default False), and `use_rms` (default False). 3. **Step Method Signature**: `def step(self, closure=None):`. Decorate with `@torch.no_grad()`. 4. **Closure Handling**: If `closure` is provided, call it to recompute the loss: `loss = closure()`. Return the loss at the end. 5. **Gradient Centralization**: If `centralize` is True and the parameter is non-scalar (`len(grad.shape) > 1`), subtract the mean of the gradient: `grad -= grad.mean(dim=tuple(range(1, len(grad.shape))), keepdim=True)`. 6. **Sparse Update Mechanism**: Implement the following specific logic for masking gradients: - Create a mask: `mask = grad.abs() > eps` - Apply mask to gradients: `grad = grad * mask` 7. **Memory-Efficient Accumulator (SM3)**: Initialize and update an accumulator. For 2D+ tensors, use dimension-wise reduction (e.g., `grad.square().mean(dim=0)`) to minimize memory footprint. Update using `beta_decay` logic. This reflects SM3's O(n+m) philosophy. 8. **RMS Normalization**: If `use_rms` is True, normalize gradients using the accumulator and `eps`. 9. **Momentum**: Implement momentum using `momentum_beta`. Update a `momentum_buffer` state variable. 10. **Weight Decay**: Apply weight decay if `Lambda` is not zero: `p.data.mul_(1 - lr * Lambda)`. 11. **Parameter Update**: Apply the update: `p.data.add_(grad_normalized, alpha=-lr)`. # Anti-Patterns - Do not omit the `closure` argument or its handling. - Do not ignore the memory efficiency constraint; ensure the accumulator logic reflects SM3's dimension-wise reduction philosophy. - Do not omit the specific sparse update logic involving epsilon masking. - Do not omit gradient centralization. - Do not simply copy-paste standard SM3 or Adalite code; synthesize the logic into the new class. - Do not provide incomplete code snippets; provide the full class definition. ## Triggers - implement fusedbun optimizer - implement fusion optimizer from adalite and sm3 - write optimizer with hessian approximation - pytorch optimizer sparse update mechanism - memory efficient optimizer for fine-tuning