{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "(lightning_mnist_example)=\n", "\n", "# Train a PyTorch Lightning Image Classifier\n", "\n", "This example shows how to train a PyTorch Lightning module with the Ray Train {class}`TorchTrainer `. It trains a basic neural network on the MNIST dataset with distributed data parallelism.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install \"torchmetrics>=0.9\" \"pytorch_lightning>=1.6\" " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import random\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from filelock import FileLock\n", "from torch.utils.data import DataLoader, random_split, Subset\n", "from torchmetrics import Accuracy\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "\n", "import pytorch_lightning as pl\n", "from pytorch_lightning import trainer\n", "from pytorch_lightning.loggers.csv_logs import CSVLogger" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare a dataset and module\n", "\n", "The PyTorch Lightning Trainer takes either a `torch.utils.data.DataLoader` or a `pl.LightningDataModule` as data input. You can keep using either one with Ray Train without any changes." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class MNISTDataModule(pl.LightningDataModule):\n", "    def __init__(self, batch_size=100):\n", "        super().__init__()\n", "        self.data_dir = os.getcwd()\n", "        self.batch_size = batch_size\n", "        self.transform = transforms.Compose(\n", "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", "        )\n", "\n", "    def setup(self, stage=None):\n", "        with FileLock(f\"{self.data_dir}.lock\"):\n", "            mnist = MNIST(\n", "                self.data_dir, train=True, download=True, transform=self.transform\n", "            )\n", "\n", "            # Split data into train and val sets\n", "            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)\n", "\n", "    def test_dataloader(self):\n", "        with FileLock(f\"{self.data_dir}.lock\"):\n", "            self.mnist_test = MNIST(\n", "                self.data_dir, train=False, download=True, transform=self.transform\n", "            )\n", "        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Next, define a simple multi-layer perceptron as a subclass of `pl.LightningModule`."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class MNISTClassifier(pl.LightningModule):\n", "    def __init__(self, lr=1e-3, feature_dim=128):\n", "        torch.manual_seed(421)\n", "        super(MNISTClassifier, self).__init__()\n", "        self.save_hyperparameters()\n", "\n", "        self.linear_relu_stack = nn.Sequential(\n", "            nn.Linear(28 * 28, feature_dim),\n", "            nn.ReLU(),\n", "            nn.Linear(feature_dim, 10),\n", "            nn.ReLU(),\n", "        )\n", "        self.lr = lr\n", "        self.accuracy = Accuracy(task=\"multiclass\", num_classes=10, top_k=1)\n", "        self.eval_loss = []\n", "        self.eval_accuracy = []\n", "        self.test_accuracy = []\n", "        pl.seed_everything(888)\n", "\n", "    def forward(self, x):\n", "        x = x.view(-1, 28 * 28)\n", "        x = self.linear_relu_stack(x)\n", "        return x\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        y_hat = self(x)\n", "        loss = torch.nn.functional.cross_entropy(y_hat, y)\n", "        self.log(\"train_loss\", loss)\n", "        return loss\n", "\n", "    def validation_step(self, val_batch, batch_idx):\n", "        loss, acc = self._shared_eval(val_batch)\n", "        self.log(\"val_accuracy\", acc)\n", "        self.eval_loss.append(loss)\n", "        self.eval_accuracy.append(acc)\n", "        return {\"val_loss\": loss, \"val_accuracy\": acc}\n", "\n", "    def test_step(self, test_batch, batch_idx):\n", "        loss, acc = self._shared_eval(test_batch)\n", "        self.test_accuracy.append(acc)\n", "        self.log(\"test_accuracy\", acc, sync_dist=True, on_epoch=True)\n", "        return {\"test_loss\": loss, \"test_accuracy\": acc}\n", "\n", "    def _shared_eval(self, batch):\n", "        x, y = batch\n", "        logits = self.forward(x)\n", "        # Use cross-entropy on the raw outputs, matching training_step.\n", "        # F.nll_loss expects log-probabilities and yields negative values here.\n", "        loss = F.cross_entropy(logits, y)\n", "        acc = self.accuracy(logits, y)\n", "        return loss, acc\n", "\n", "    def on_validation_epoch_end(self):\n", "        avg_loss = torch.stack(self.eval_loss).mean()\n", "        avg_acc = torch.stack(self.eval_accuracy).mean()\n", "        self.log(\"val_loss\", avg_loss, sync_dist=True)\n", "        self.log(\"val_accuracy\", avg_acc, sync_dist=True)\n", "        self.eval_loss.clear()\n", "        self.eval_accuracy.clear()\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", "        return optimizer" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "You don't need to modify the definition of the PyTorch Lightning model or datamodule." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define a training function\n", "\n", "This code defines a {ref}`training function ` for each worker. Compared with the original PyTorch Lightning code, the training function differs in three main ways:\n", "\n", "- Distributed strategy: Use {class}`RayDDPStrategy `.\n", "- Cluster environment: Use {class}`RayLightningEnvironment `.\n", "- Parallel devices: Always set `devices=\"auto\"` to use all available devices configured by ``TorchTrainer``.\n", "\n", "See {ref}`Getting Started with PyTorch Lightning ` for more information.\n", "\n", "\n", "For checkpoint reporting, Ray Train provides a minimal {class}`RayTrainReportCallback ` class that reports metrics and checkpoints at the end of each training epoch. For more complex checkpoint logic, implement custom callbacks. See {ref}`Saving and Loading Checkpoint `."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "use_gpu = True # Set to False if you want to run without GPUs\n", "num_workers = 4" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n", "from ray.train.torch import TorchTrainer\n", "from ray.train.lightning import (\n", "    RayDDPStrategy,\n", "    RayLightningEnvironment,\n", "    RayTrainReportCallback,\n", "    prepare_trainer,\n", ")\n", "\n", "def train_func_per_worker():\n", "    model = MNISTClassifier(lr=1e-3, feature_dim=128)\n", "    datamodule = MNISTDataModule(batch_size=128)\n", "\n", "    trainer = pl.Trainer(\n", "        devices=\"auto\",\n", "        strategy=RayDDPStrategy(),\n", "        plugins=[RayLightningEnvironment()],\n", "        callbacks=[RayTrainReportCallback()],\n", "        max_epochs=10,\n", "        accelerator=\"gpu\" if use_gpu else \"cpu\",\n", "        log_every_n_steps=100,\n", "        logger=CSVLogger(\"logs\"),\n", "    )\n", "\n", "    trainer = prepare_trainer(trainer)\n", "\n", "    # Train model\n", "    trainer.fit(model, datamodule=datamodule)\n", "\n", "    # Evaluation on the test dataset\n", "    trainer.test(model, datamodule=datamodule)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now put everything together:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)\n", "\n", "run_config = RunConfig(\n", "    name=\"ptl-mnist-example\",\n", "    storage_path=\"/tmp/ray_results\",\n", "    checkpoint_config=CheckpointConfig(\n", "        num_to_keep=3,\n", "        checkpoint_score_attribute=\"val_accuracy\",\n", "        checkpoint_score_order=\"max\",\n", "    ),\n", ")\n", "\n", "trainer = TorchTrainer(\n", "    train_func_per_worker,\n", "    scaling_config=scaling_config,\n", "    run_config=run_config,\n", ")" ] }, { "attachments": {}, "cell_type": 
"markdown", "metadata": {}, "source": [ "Now fit your trainer:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<h3>Tune Status</h3>\n",
"<table>\n",
"<tbody>\n",
"<tr><td>Current time:</td><td>2023-08-07 23:41:11</td></tr>\n",
"<tr><td>Running for:</td><td>00:00:39.80</td></tr>\n",
"<tr><td>Memory:</td><td>24.2/186.6 GiB</td></tr>\n",
"</tbody>\n",
"</table>\n",
"<h3>System Info</h3>\n",
"Using FIFO scheduling algorithm.<br>Logical resource usage: 1.0/48 CPUs, 4.0/4 GPUs\n",
"<h3>Trial Status</h3>\n",
"<table>\n",
"<thead>\n",
"<tr><th>Trial name</th><th>status</th><th>loc</th><th>iter</th><th>total time (s)</th><th>train_loss</th><th>val_accuracy</th><th>val_loss</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>TorchTrainer_78346_00000</td><td>TERMINATED</td><td>10.0.6.244:120026</td><td>10</td><td>29.0221</td><td>0.0315938</td><td>0.970002</td><td>-12.3466</td></tr>\n",
"</tbody>\n",
"</table>
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(TorchTrainer pid=120026)\u001b[0m Starting distributed worker processes: ['120176 (10.0.6.244)', '120177 (10.0.6.244)', '120178 (10.0.6.244)', '120179 (10.0.6.244)']\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m Setting up process group for: env:// [rank=0, world_size=4]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m [rank: 0] Global seed set to 888\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m GPU available: True (cuda), used: True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m TPU available: False, using: 0 TPU cores\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m IPU available: False, using: 0 IPUs\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/9912422 [00:00` for a tutorial on using Ray Train and PyTorch Lightning \n", "\n", "* {ref}`Ray Train Examples ` for more use cases" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "orphan": true, "vscode": { "interpreter": { 
"hash": "a8c1140d108077f4faeb76b2438f85e4ed675f93d004359552883616a1acd54c" } } }, "nbformat": 4, "nbformat_minor": 4 }