| import torch.optim as optim |
| import torch.nn as nn |
|
|
class Trainer:
    """Minimal supervised training loop around a model.

    Bundles the model with a cross-entropy criterion and an SGD optimizer
    over the model's parameters, and exposes a single ``train`` method.
    """

    def __init__(self, model, lr=0.01):
        """Create the loss function and optimizer for ``model``.

        Args:
            model: Object providing ``parameters()`` for the optimizer and
                callable on a flattened input batch.
            lr: Learning rate handed to ``optim.SGD``. Defaults to 0.01,
                the value that was previously hard-coded.
        """
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)

    def train(self, train_loader, epochs=5):
        """Run ``epochs`` passes over ``train_loader``, printing one line per epoch.

        Each batch is reshaped with ``view(-1, 28 * 28)`` before being fed to
        the model, so inputs are presumably 28x28 images (TODO confirm
        against the data pipeline).

        The printed loss is that of the *last* batch of the epoch, not an
        epoch average. If an epoch yields no batches (empty loader, or a
        one-shot iterable already consumed by an earlier epoch) a
        placeholder is printed instead of failing on an unbound ``loss`` or
        silently re-reporting the previous epoch's value.

        NOTE: this method does not call ``self.model.train()``; the model
        stays in whatever mode the caller left it in.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` batches.
            epochs: Number of passes over ``train_loader``. Defaults to 5.
        """
        for epoch in range(epochs):
            # Reset per epoch so an epoch without batches is detectable.
            loss = None

            for images, labels in train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(images.view(-1, 28 * 28))
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

            if loss is None:
                print(f"Epoch [{epoch+1}/{epochs}], Loss: n/a (no batches)")
            else:
                # .item() is only called once per epoch, for reporting.
                print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")