
[DOC] Update mnist.py example

Open orion160 opened this issue 1 year ago • 10 comments

Update the example at https://github.com/pytorch/examples/blob/main/mnist/main.py to use torch.compile features.

orion160 avatar Jun 12 '24 17:06 orion160

Proposal:

diff --git a/mnist/main.py b/mnist/main.py
index 184dc47..a3cffd1 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -3,13 +3,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
 from torch.optim.lr_scheduler import StepLR
 
 
 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ class Net(nn.Module):
         return output
 
 
-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data = data.to(device, memory_format=torch.channels_last)
+        target = target.to(device)
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+            if args.dry_run:
+                break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data = data.to(device, memory_format=torch.channels_last)
+        target = target.to(device)
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
             if args.dry_run:
                 break
 
@@ -58,43 +82,125 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )
 
 
-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        default=False,
+        help="use automatic mixed precision",
+    )
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()
 
@@ -107,32 +213,43 @@ def main():
     else:
         device = torch.device("cpu")
 
-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-        ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
+
+    data_dir = args.data_dir
+
+    dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    model = Net().to(device, memory_format=torch.channels_last)
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
+    if use_cuda and args.use_amp:
+        scaler = torch.amp.GradScaler(device.type)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()
 
@@ -140,5 +257,5 @@ def main():
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

orion160 avatar Jun 12 '24 17:06 orion160

CC @svekars

orion160 avatar Jun 12 '24 23:06 orion160

Hey, can I work on this issue?

doshi-kevin avatar Aug 22 '24 17:08 doshi-kevin

Hey @orion160, I have made the changes and created a pull request at #1282. Could you please check it once?

doshi-kevin avatar Aug 23 '24 11:08 doshi-kevin

@orion160 any updates on this?

doshi-kevin avatar Aug 26 '24 08:08 doshi-kevin

@doshi-kevin I am not a PyTorch maintainer, but it is great that you are interested in contributing.

Though I see two problems. First, you are applying the commit yourself, which means there is neither authoring nor co-authoring recorded in the git history.

Second, your PR is about 600+ commits long, so it is possible that you did a merge and created an altered git history that you are trying to push.

The usual approach would be to rebase your changes onto the upstream head and then open the PR.

orion160 avatar Aug 26 '24 14:08 orion160

Sure, I will make the changes, or else create a new fork and a new pull request. @orion160

doshi-kevin avatar Aug 27 '24 06:08 doshi-kevin

Hey @orion160, please check #1283.

doshi-kevin avatar Aug 27 '24 06:08 doshi-kevin

It seems good

orion160 avatar Aug 27 '24 22:08 orion160

Hope this gets pushed @orion160 @msaroufim

doshi-kevin avatar Aug 28 '24 05:08 doshi-kevin