"""Train a small MLP on MNIST with early stopping on validation loss.

The script trains for up to 50 epochs, tracks the lowest validation loss
seen so far, and stops once the loss has failed to improve for ``patience``
consecutive epochs.  The weights from the best epoch are written to
``best_model.pth`` and loaded back into ``model``.
"""

import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Data preprocessing and loading.
# NOTE(review): (0.1307, 0.3081) are presumably the MNIST mean/std -- confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE: the train=False split doubles as the validation set that drives
# early stopping below; there is no separate held-out test evaluation.
val_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define a simple model: 784 flattened inputs -> 128 hidden units -> 10 logits.
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Parameters for early stopping
patience = 5
best_loss = float('inf')
counter = 0
# state_dict() returns references to the live parameter tensors, which the
# optimizer updates in place.  Deep-copy so the snapshot is frozen in time.
best_model_wts = copy.deepcopy(model.state_dict())

# Training loop with early stopping
for epoch in range(50):
    model.train()
    for inputs, targets in train_loader:
        inputs = inputs.view(inputs.size(0), -1)  # Flatten the inputs
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.view(inputs.size(0), -1)  # Flatten the inputs
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

    # Average of the per-batch mean losses.
    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1}, Validation Loss: {val_loss}")

    # Check if validation loss improved
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        # Snapshot (not alias) the best weights; without the deep copy the
        # "best" weights would keep changing with every later training step.
        best_model_wts = copy.deepcopy(model.state_dict())
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break

# Save the best model to disk
torch.save(best_model_wts, 'best_model.pth')
print("Best model saved to best_model.pth")

# Restore the best model weights
model.load_state_dict(best_model_wts)
print("Best model weights restored")

# Load the best model weights from disk
model.load_state_dict(torch.load('best_model.pth'))
print("Best model weights restored")