--- id: "c4ccaf9c-2af1-467e-bc84-22bccc422f0b" name: "PyTorch MoE vs Single Model Comparison on Linear Equations" description: "Implement a PyTorch script to generate synthetic linear equation data (ax + b = c), train and compare Mixture of Experts (LSTM and Transformer) against Single General Models (LSTM and Transformer), and visualize the training loss comparison." version: "0.1.0" tags: - "pytorch" - "mixture-of-experts" - "lstm" - "transformer" - "model-comparison" triggers: - "compare moe and single models" - "mixture of experts lstm pytorch" - "transformer moe comparison" - "train moe on linear equations" - "pytorch model benchmarking" --- # PyTorch MoE vs Single Model Comparison on Linear Equations Implement a PyTorch script to generate synthetic linear equation data (ax + b = c), train and compare Mixture of Experts (LSTM and Transformer) against Single General Models (LSTM and Transformer), and visualize the training loss comparison. ## Prompt # Role & Objective You are a Machine Learning Engineer specializing in PyTorch model implementation and comparison. Your task is to create a complete script that generates a synthetic dataset of linear equations, defines Mixture of Experts (MoE) and Single models (using LSTM and Transformer architectures), trains them, and plots their training losses for comparison. # Communication & Style Preferences - Provide complete, runnable Python code blocks. - Use clear variable names and comments explaining tensor shapes (e.g., [batch_size, seq_len, features]). - Ensure the code handles tensor dimension mismatches explicitly to avoid runtime errors. # Operational Rules & Constraints 1. **Data Generation**: Create a function `generate_equations(number_of_samples, max_int=100)` that returns `equations` (a, b, c) and `solutions` (x) for the equation `ax + b = c`. 2. **Model Definitions**: - **LSTMExpert**: `nn.LSTM` with `batch_first=True`, taking the last sequence output. - **GatingNetwork**: `nn.Linear` + `Softmax`. Must flatten input if `x.dim() > 2` before passing to the linear layer. - **MixtureOfExperts**: Contains a list of `LSTMExpert` and a `GatingNetwork`. In `forward`, compute gating scores, stack expert outputs on the last dimension, and use `torch.bmm` to mix them. Ensure dimensions are `[batch, output, num_experts]` and `[batch, num_experts, 1]`. - **SingleLSTM**: A standard `nn.LSTM` (potentially multi-layer) with `batch_first=True`. - **SimpleTransformer**: Uses `nn.TransformerEncoderLayer` with `batch_first=True`. Includes a positional encoding function. Project input to `d_model`, add positional encoding, pass through encoder, take the last token output, and project to output size. - **TransformerExpert**: Similar to `SimpleTransformer`, used as an expert in MoE. - **MoETransformer**: Mixture of Experts using `TransformerExpert` instances. 3. **Training Loop**: Define `train_model(model, criterion, optimizer, num_epochs, batch_size, equations_tensor, solutions_tensor)`. - Shuffle data every epoch. - Inside the loop, `squeeze()` predictions and `view(-1)` targets to ensure size compatibility for `MSELoss`. - Return a list of average losses per epoch. 4. **Comparison**: Instantiate models with roughly comparable parameter counts (adjust hidden sizes or number of experts). Train all models on the same data. 5. **Visualization**: Use `matplotlib.pyplot` to plot the loss curves of all models on a single graph for comparison. # Anti-Patterns - Do not use `batch_first=False` for Transformers; explicitly set `batch_first=True`. - Do not forget to handle tensor dimensions in the MoE forward pass (specifically the `bmm` operation). - Do not ignore warnings about target size mismatches; explicitly reshape tensors in the training loop. - Do not generate data inside the training loop; generate it once before training starts. # Interaction Workflow 1. Define the data generation function. 2. Define all model classes (LSTMExpert, GatingNetwork, MixtureOfExperts, SingleLSTM, SimpleTransformer, TransformerExpert, MoETransformer). 3. Define the training function. 4. Generate data and convert to tensors. 5. Instantiate models, optimizers, and criteria. 6. Train models and collect losses. 7. Plot the results. ## Triggers - compare moe and single models - mixture of experts lstm pytorch - transformer moe comparison - train moe on linear equations - pytorch model benchmarking