Skip to content

Commit 8e917a4

Browse files
committed
Address review: torch.accelerator device selection and prose
1 parent ae10152 commit 8e917a4

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,15 @@
 77  77 |
 78  78 |  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
 79  79 |                                          download=True, transform=transform)
     80 |+ # num_workers=0 avoids multiprocessing spawn issues when running this file as
     81 |+ # ``python cifar10_tutorial.py`` on macOS/Windows (see note above).
 80  82 |  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
 81     |-                                           shuffle=True, num_workers=2)
     83 |+                                           shuffle=True, num_workers=0)
 82  84 |
 83  85 |  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
 84  86 |                                         download=True, transform=transform)
 85  87 |  testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
 86     |-                                          shuffle=False, num_workers=2)
     88 |+                                          shuffle=False, num_workers=0)
 87  89 |
 88  90 |  classes = ('plane', 'car', 'bird', 'cat',
 89  91 |             'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
@@ -292,28 +294,27 @@ def forward(self, x):
 292 294 |  ########################################################################
 293 295 |  # Okay, so what next?
 294 296 |  #
 295     |- # How do we run these neural networks on the GPU?
     297 |+ # How do we run these neural networks on a GPU or other accelerator?
 296 298 |  #
 297 299 |  # Training on GPU
 298 300 |  # ----------------
 299     |- # Just like how you transfer a Tensor onto the GPU, you transfer the neural
 300     |- # net onto the GPU.
     301 |+ # Just like how you transfer a Tensor onto a device, you transfer the neural
     302 |+ # net onto that device.
 301 303 |  #
 302     |- # Let's first select a device. Prefer CUDA when available, otherwise use MPS
 303     |- # (Apple Silicon), and fall back to CPU.
 304     |- if torch.cuda.is_available():
 305     |-     device = torch.device("cuda:0")
 306     |- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
 307     |-     device = torch.device("mps")
 308     |- else:
 309     |-     device = torch.device("cpu")
 310     |-
 311     |- # This prints the selected device, e.g. "cuda:0", "mps", or "cpu".
 312     |-
     304 |+ # Let's first select a device. This picks the fastest available accelerator,
     305 |+ # or falls back to CPU (same pattern as ``quickstart_tutorial.py``).
     306 |+ device = (
     307 |+     torch.accelerator.current_accelerator().type
     308 |+     if torch.accelerator.is_available()
     309 |+     else "cpu"
     310 |+ )
     311 |+
     312 |+ # ``device`` is a string such as ``"cuda"``, ``"mps"``, or ``"cpu"``, which you
     313 |+ # can pass to ``Tensor.to(...)`` and ``nn.Module.to(...)``.
 313 314 |  print(device)
 314 315 |
 315 316 |  ########################################################################
 316     |- # The rest of this section assumes that ``device`` is an accelerator device.
     317 |+ # The snippets below show how to move the model and batch data to ``device``.
 317 318 |  #
 318 319 |  # Then these methods will recursively go over all modules and convert their
 319 320 |  # parameters and buffers to tensors on ``device``:

0 commit comments

Comments
 (0)