import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import copy

# Toy data
x = torch.randn(100, 10)  # 100 samples, 10 features each
y = torch.randn(100, 1)   # 100 targets

# Split into training and validation sets
x_train, x_val = x[:80], x[80:]
y_train, y_val = y[:80], y[80:]

# Wrap in a Dataset and DataLoader for minibatch training
train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)  # Minibatch size of 20; shuffle training data each epoch
val_loader = DataLoader(val_dataset, batch_size=20)  # Minibatch size of 20

# Define a simple model with batch normalization
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.BatchNorm1d(5),  # Batch normalization
    nn.ReLU(),
    nn.Linear(5, 1)
)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Variables for early stopping
n_epochs_stop = 10          # Patience: stop after this many epochs without improvement
min_val_loss = float('inf')
epochs_no_improve = 0

# Variable to store the best model parameters
best_model_params = copy.deepcopy(model.state_dict())

# Training loop with early stopping
for epoch in range(100):
    # Training: train mode makes BatchNorm normalize with per-batch statistics
    model.train()
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

    # Validation: eval mode makes BatchNorm use its running statistics instead
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            outputs = model(batch_x)
            val_loss += criterion(outputs, batch_y).item()

    # Check for early stopping; snapshot the parameters whenever validation improves
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        epochs_no_improve = 0
        best_model_params = copy.deepcopy(model.state_dict())
    else:
        epochs_no_improve += 1
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break

    # Note: 'loss' here is from the last training minibatch of the epoch
    print(f'Epoch {epoch+1}, Loss: {loss.item()}, Validation Loss: {val_loss}')

# Load the best model parameters
model.load_state_dict(best_model_params)
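
# --- Illustrative addition (not part of the original script): sanity-check the restored model ---
# A minimal sketch, assuming the variables defined above (model, criterion, val_loader,
# min_val_loss) are still in scope. Re-evaluating the restored parameters in eval mode
# should reproduce the best validation loss recorded during training, confirming that
# load_state_dict actually rolled the model back to its best checkpoint.
model.eval()
with torch.no_grad():
    restored_val_loss = sum(criterion(model(bx), by).item() for bx, by in val_loader)
print(f'Restored model validation loss: {restored_val_loss:.4f} (best recorded: {min_val_loss:.4f})')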