Editorial hint: This post was revised in parts on March 18th, 19th and 22nd, 2025, after some new tests and insights. The changes did not concern the result data of the performed experiments, but their interpretation.
In the last post of this mini-series we saw that some Torchvision datasets have a directly accessible property “data“. It contains image data in a format specific to the dataset. In the case of MNIST and FashionMNIST (and for many other sets) these data are already Torch tensors. However, because tensors for gray images typically are squeezed, they do not fit the format a neural network [NN] requires for image tensors: they lack a dimension for indexing the usual color channels. Thus, we cannot use these tensors directly as input for a NN-model.
For standard usage of an image set like MNIST, we therefore define a PyTorch dataset object with an added chain of transformation operations. A standard operation is e.g. ToTensor(), which remedies deficits regarding typical format expectations of NNs. All defined transformation operations are invoked whenever we retrieve an indexed element of the dataset. Without transform operations, such an element typically provides a tuple of image data in PIL format and label data (for MNIST/FashionMNIST as integers). The ToTensor() function eventually delivers usable PyTorch tensors. The resulting image tensors follow the required [C, H, W] format (channels, height, width), with pixel values scaled to the range [0, 1].
An (obligatory) __getitem__() method of a specific Torch dataset class controls how the data of a dataset object are handled internally. In the case of MNIST this includes a conversion of the raw image tensor to the PIL format and the subsequent application of the user-defined transform functions.
In this post we look a bit at the setup of a PyTorch “dataloader” and what it does with dataset elements on the CPU. My experience with TF2 data pipelines tells me that the activities of a dataloader may become a performance bottleneck, in particular during the training of relatively small NNs.
In some experiments below we will indeed see that the GPU operates far below its capabilities for small NNs when a dataloader feeds a long sequence of small data batches into it. This effect may hit ML-applications on standard PC systems with newer GPUs. The question is whether the GPU in such a situation waits for data provision from CPU and RAM, or whether other factors limit the overall performance during model training.
In this post we will identify some initial measures which help to use the GPU more efficiently, at least in parts. Below, I provide code snippets which enable readers to perform their own test runs.
Test data and test NN: To test the functionality of a PyTorch dataloader we first define a MNIST dataset object. We then create a dataloader object which uses this dataset. Afterward we pass the data to a simple CNN with 4 successive simple Conv2D layers, the last three of which work with a stride of 2. The resulting 256 feature maps (3x3 each) are flattened and fed into two fully connected layers producing an output via log_softmax. I will not go into the details and the disadvantages of this extremely simple NN as they are not of interest here. We are primarily interested in CPU and GPU performance and their relative load during NN-training.
Plan of work: We vary parameters of the dataloader object in successive training runs of our NN. During the runs we roughly measure the CPU and GPU load as well as the turnaround time for 40 training epochs. We will find that the parameters for the number of parallel dataloader workers (num_workers) and for the batch size are decisive for performance optimization.
Hint: The first sections of this post contain somewhat standard preparation steps.
Previous post
- Post I: PyTorch / datasets / dataloader / data transfer to GPU – I – properties of some torchvision datasets
Preparation and code snippets
I should say that I did all of my tests in a Jupyterlab environment. I assume that the reader is familiar with such a browser-based environment. I further assume that you know how to restart the notebook kernel and how to rerun a notebook after substantial changes to a test-run’s setup.
We first load some Python modules. Afterward, we define parameters for the setup of the training and the dataloader. We also check for a CUDA device, in my case an Nvidia GeForce RTX 4060 Ti.
import os, sys, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize, Compose
import matplotlib.pyplot as plt
from PIL import Image
Some of the following control parameters will become clearer later on. I will later also indicate which parameters have to be changed for certain runs.
num_samples = 60000 # MNIST & FashionMNIST
# Parameters for NN training
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Learning rate
LR = 1.e-4
# Batch Size
BATCH_SIZE_TRAIN = 32
# Number of epochs
NUM_EPOCHS = 41   # range(NUM_EPOCHS) runs the epochs 0 to 40
NUM_BATCHES = math.ceil(num_samples / BATCH_SIZE_TRAIN)
print("Num_BATCHES = ", NUM_BATCHES)
# Print progress output every NP_Batches batches
NP_Batches = int(num_samples / BATCH_SIZE_TRAIN / 100.) * 20
#print(NP_Batches)
# Preload all data to the GPU?
b_PRELOAD = False   # stays False in this post; preloading is used in the next post
# Parameters for setting up a dataloader
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# if True => Tensor data set with all transforms done once
b_TENSOR_DATASET = False
if b_PRELOAD:
    b_TENSOR_DATASET = True
# Dataloader parameters
b_STANDARD_PARAMS = True
NUM_WORKERS = 0
b_PIN_MEMORY = False
b_PERSISTENT_WORKERS = False
b_SHUFFLE = True
b_ASYNC = False
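The batch-related numbers in the parameter block are easy to check with plain Python. The following sketch (the helper name batch_stats is hypothetical) reproduces NUM_BATCHES and NP_Batches for the batch sizes used later in this post. It also shows the size of the last batch, which a DataLoader delivers even when it is only partially filled, as long as drop_last=False (the default):

```python
import math

def batch_stats(num_samples, batch_size):
    """Number of batches, size of the last batch and the print interval NP_Batches."""
    num_batches = math.ceil(num_samples / batch_size)        # as NUM_BATCHES above
    last_batch = num_samples - (num_batches - 1) * batch_size
    np_batches = int(num_samples / batch_size / 100.) * 20   # as NP_Batches above
    return num_batches, last_batch, np_batches

for bs in (32, 64, 128, 256):
    n, last, npb = batch_stats(60000, bs)
    print(f"BS={bs:3d}: {n:4d} batches, last batch {last:3d} samples, output every {npb} batches")
```

For BS=32 this gives 1875 batches per epoch, i.e. 1875 separate transfers of an image batch (plus a label batch) to the GPU. For BS=256 there are only 235.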
Note: We will not use the options controlled by the parameters “b_PRELOAD” and “b_TENSOR_DATASET” in this post, i.e. both stay False. The parameter “b_TENSOR_DATASET” is nevertheless a decisive one. In the next post of this series we will perform basic data transformations ahead of the NN-training and then create datasets built on tensors. The parameter “b_PRELOAD” controls whether the elements of such a tensor dataset will be preloaded to the GPU ahead of a training run.
We now test for CUDA:
# Checking for cuda-device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
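if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
    print('Cached:', round(torch.cuda.memory_reserved(0) / 1024**3, 1), 'GB')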
-------------------------
Using device: cuda
NVIDIA GeForce RTX 4060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached: 0.0 GB
This is a standard procedure for PyTorch. I use CUDA 12.8, cudnn 9.7 and Nvidia drivers 570.127.06.
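To document your own software stack in the same way, you can ask PyTorch directly. A small sketch follows. Note that torch.version.cuda reports the CUDA version PyTorch was built against, which is not necessarily the version installed system-wide:

```python
import torch

print("PyTorch :", torch.__version__)
print("CUDA    :", torch.version.cuda)                # None for CPU-only builds
print("cuDNN   :", torch.backends.cudnn.version())    # None if cuDNN is not available
if torch.cuda.is_available():
    print("GPU     :", torch.cuda.get_device_name(0))
    free, total = torch.cuda.mem_get_info()           # bytes on the current device
    print(f"Free/total GPU memory: {free / 1024**3:.1f} / {total / 1024**3:.1f} GB")
```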
Definition and download of a dataset (for the MNIST example)
We need to download the MNIST data, of course. The following statements perform a download automatically if the dataset is not already found at a defined path. We define a dataset object for the MNIST training images and labels:
# Path to MNIST data
ds_path = '/mnt_ramdisk/MNIST_data/'
# Definition of a dataset with standard parameters
# and a chain of basic transformation operations
train_data = datasets.MNIST(root=ds_path,
                            train=True,
                            download=True,
                            transform=Compose([
                                ToTensor(),
                                Normalize((0.1302,), (0.3069,))
                            ])
                            )
print(train_data.__len__)
-------------------------------------------
<bound method MNIST.__len__ of Dataset MNIST
Number of datapoints: 60000
Root location: /mnt_ramdisk/MNIST_data/
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.1302,), std=(0.3069,))
)>
-----------------------
Typically, I have a RAM disk mounted on /mnt_ramdisk. So, reading the data fast enough from their storage certainly is not a problem for our tests. Even if we used a standard hard disk, Linux would cache the files in RAM anyway, given a sufficiently large memory.
Note: In contrast to the example in the first post, I have added a standardization of the data to the transform chain. In addition to its internal data handling (conversion to the PIL format), the dataset will create tensors in the right format for us. It will also adjust the data with respect to their mean value and standard deviation. This happens on the fly during the training and, under the control of the dataloader, batch for batch: the dataloader fetches and transforms all elements required for a batch and bundles them into a surrounding batch tensor.
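The bundling step itself is done by a collate function. The following sketch shows what PyTorch's default_collate does with a list of four (image, label) tuples. The tuples are random stand-ins for what __getitem__() returns after ToTensor() and Normalize():

```python
import torch
from torch.utils.data.dataloader import default_collate

# Four hypothetical (image, label) tuples, as returned by a dataset's __getitem__()
samples = [(torch.randn(1, 28, 28), label) for label in (3, 1, 4, 1)]

X, y = default_collate(samples)   # stacks the images, turns the int labels into a tensor
print(X.shape)   # torch.Size([4, 1, 28, 28])
print(y)         # tensor([3, 1, 4, 1])
```

The batch dimension is prepended to the [C, H, W] image format, which gives the [N, C, H, W] layout the CNN expects.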
If you wonder where the mean and the standard deviation values provided to the Normalize()-function come from: We can get these data directly from the downloaded tensors of the MNIST image data:
pics = train_data.data.to(torch.float32).unsqueeze(1)
pics2 = pics / 256.
print("mean = ", pics2.mean(), " :: stddev = ", pics2.std() )
The first statement above should be clear after the first post of this series. This snippet gives us an output of:
mean = tensor(0.1302) :: stddev = tensor(0.3069).
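Readers of the first post in this mini-series may now ask where the scaling of the pixel values to the range [0, 1] is done when we later use the dataset with a dataloader. Well, have a look at what we found for the method __getitem__() in the last post: it converts the image tensor first to a Numpy array and afterwards to the PIL format. The subsequent ToTensor() transform then converts the PIL image (mode “L”, uint8) to a float tensor and divides the values by 255. Strictly speaking, the division by 256 in the snippet above is therefore slightly off. Dividing by 255 gives the commonly quoted MNIST values mean ≈ 0.1307 and stddev ≈ 0.3081. The small difference has no influence on the performance tests of this post.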
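Because mean and standard deviation both scale linearly when all pixel values are multiplied by a constant, the statistics obtained with the division by 256 can be converted to the ToTensor() convention (division by 255) without touching the data again. A small arithmetic sketch:

```python
# Values obtained above with pics / 256.
mean_256, std_256 = 0.1302, 0.3069

# For X' = c * X:  mean(X') = c * mean(X)  and  std(X') = c * std(X)
factor = 256 / 255
mean_255, std_255 = mean_256 * factor, std_256 * factor
print(f"mean = {mean_255:.4f}, std = {std_255:.4f}")   # mean = 0.1307, std = 0.3081
```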
Creation of a so-called tensor dataset for runs with prepared tensors
The next sequence of statements prepares another dataset composed of already fully normalized tensors. We will use such a “tensor dataset” in runs of the next post in this series. For now, note that we must use a special class “TensorDataset” to create a dataset object based on image tensors of e.g. dtype=torch.float32.
We first simply use the pics2-tensor from above to get yet another tensor “pics3” of normalized data.
b_print = False
if b_TENSOR_DATASET:
    trans = transforms.Compose([Normalize((0.1302,), (0.3069,))])
    # Perform the transformation
    start_time = time.perf_counter()
    pics3 = trans(pics2)
    end_time = time.perf_counter()
    cpu_time = end_time - start_time
    print("CPU time for normalization: ", cpu_time)
    print()
    print("Shape images = ", pics3.shape)
    print()
    if b_print:
        print("Shape of img tensor : ", pics3[1, 0].shape)
        print()
        print(pics3[1, 0])
    # Label tensor of the MNIST training set
    labels = train_data.targets
    # !!!!!
    # When preloading do not forget that you must use
    # standard parameters in the dataloader.
    # Multiple workers and memory pinning are not possible for data on the GPU!
    if b_PRELOAD:
        # Label tensors to GPU
        labels_pre = labels.to(device)
        print(labels_pre.shape)
        # Normalized img data to the GPU
        picts_pre = pics3.to(device)
        print(picts_pre.shape)
------------------------
CPU time for normalization: 0.04189275699991413
Shape images = torch.Size([60000, 1, 28, 28])
torch.Size([60000])
torch.Size([60000, 1, 28, 28])
We then create our “tensor dataset” as an instance of the class TensorDataset:
if b_TENSOR_DATASET:
    if not b_PRELOAD:
        tens_train_ds = torch.utils.data.TensorDataset(pics3, labels)
    else:
        tens_train_ds = torch.utils.data.TensorDataset(picts_pre, labels_pre)
        print("Preload - dataset formed by loaded tensors")
As said, these variants only become important in a forthcoming post.
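For orientation, a TensorDataset simply wraps tensors that share the same size in their first dimension. Indexing returns the corresponding slices as a tuple, without any transform step. A minimal sketch with random stand-in tensors:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical, already normalized image tensors [N, 1, 28, 28] and integer labels [N]
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))

ds = TensorDataset(images, labels)   # all tensors must have the same size in dim 0
x, y = ds[7]                          # returns (images[7], labels[7]); no transforms involved
print(len(ds), x.shape, y.dtype)      # 100 torch.Size([1, 28, 28]) torch.int64
```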
Definition and a first usage of a dataloader
How do we define a dataloader? It is easy.
if not b_PRELOAD:
    if not b_TENSOR_DATASET:
        if b_STANDARD_PARAMS:
            train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE_TRAIN, shuffle=b_SHUFFLE)
        else:
            train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE_TRAIN, shuffle=b_SHUFFLE,
                                          num_workers=NUM_WORKERS, pin_memory=b_PIN_MEMORY,
                                          persistent_workers=b_PERSISTENT_WORKERS)
        #test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE_TEST, shuffle=False)
    else:
        if b_STANDARD_PARAMS:
            train_tens_loader = DataLoader(tens_train_ds, batch_size=BATCH_SIZE_TRAIN, shuffle=b_SHUFFLE)
        else:
            train_tens_loader = DataLoader(tens_train_ds, batch_size=BATCH_SIZE_TRAIN, shuffle=b_SHUFFLE,
                                           num_workers=NUM_WORKERS, pin_memory=b_PIN_MEMORY,
                                           persistent_workers=b_PERSISTENT_WORKERS)
            print("Tensor Loader / NO Preload / NON-standard params", " :: NW = ", NUM_WORKERS,
                  " :: BS = ", BATCH_SIZE_TRAIN, " :: Shuff = ", b_SHUFFLE)
        #test_loader = DataLoader(tens_train_ds, batch_size=BATCH_SIZE_TEST, shuffle=False)
if b_PRELOAD:
    # We absolutely need standard parameters !!!
    train_tens_loader = DataLoader(tens_train_ds, batch_size=BATCH_SIZE_TRAIN, shuffle=b_SHUFFLE)
    print("Tensor Loader / PRELOAD / standard params", " :: BS = ", BATCH_SIZE_TRAIN, " :: Shuff = ", b_SHUFFLE)
The first simple variant “DataLoader(train_data, batch_size=BATCH_SIZE_TRAIN, shuffle=True)” works with standard parameters and shuffles the data during training. There are, however, more parameters available; see the PyTorch documentation of torch.utils.data.DataLoader. I will come back to some of them when we activate the respective control parameters.
The complicated conditions regarding a dataloader for fully prepared tensor data will not be used in this post, but become central in the next one. Note, however, that a dataloader for tensors preloaded to the GPU requires standard parameters! Things like memory pinning and multiple workers would not work in this case.
Statements for separate dataloaders of test data are not used in this post. Note, however, that no shuffling would be required for the test data.
What is a dataloader?
I just quote from the PyTorch documentation: “It represents a Python iterable over a dataset”.
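What “iterable over a dataset” means in practice can be seen with a tiny toy dataset of 10 elements. Each call of iter() starts a new pass over the data. The number of batches is ceil(N / batch_size) unless drop_last=True discards the incomplete last batch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 one-element "images" with labels 0..9
ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

loader = DataLoader(ds, batch_size=4, shuffle=False)
sizes = [len(y) for _, y in loader]             # one full pass over the iterable
print(len(loader), sizes)                       # 3 [4, 4, 2]

loader_drop = DataLoader(ds, batch_size=4, drop_last=True)
sizes_drop = [len(y) for _, y in loader_drop]
print(len(loader_drop), sizes_drop)             # 2 [4, 4]

X0, y0 = next(iter(loader))                     # explicit use of the iterator protocol
print(y0)                                       # tensor([0, 1, 2, 3])
```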
How can we use a dataloader?
This is easy, too: We just have to deal with the iterable. We use the enumerate function for this purpose. See the following simple example:
# Example for using the dataloader
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
b_run_loader = True
if b_run_loader and not b_PRELOAD:
    if not b_TENSOR_DATASET:
        d_loader = train_dataloader
    else:
        d_loader = train_tens_loader
        print("using tensor dataset")
    for batch_i, (X, y) in enumerate(d_loader):
        mean = X.mean()
        stddev = X.std()
        # Some output
        # The current position is given by all batches so far + len(X)
        if batch_i % NP_Batches == 0:
            print("batch = ", "{:4.0f}".format(batch_i),
                  " :: mean : ", "{:10.5f}".format(mean),
                  " :: stddev : ", "{:10.5f}".format(stddev)
                  )
------------------------------------------------------
batch = 0 :: mean : -0.00167 :: stddev : 1.00048
batch = 50 :: mean : -0.00630 :: stddev : 0.99772
batch = 100 :: mean : -0.00196 :: stddev : 1.00246
batch = 150 :: mean : -0.02309 :: stddev : 0.97656
batch = 200 :: mean : -0.00836 :: stddev : 0.99679
batch = 250 :: mean : -0.00524 :: stddev : 0.99781
batch = 300 :: mean : 0.00660 :: stddev : 1.01225
batch = 350 :: mean : -0.00551 :: stddev : 0.99595
batch = 400 :: mean : -0.00455 :: stddev : 0.99664
batch = 450 :: mean : 0.01409 :: stddev : 1.01712
The result data shown were produced for the standard dataset based on non-normalized raw data. I.e., we used the “train_dataloader” object, which in turn invoked the “train_data” dataset object. During the resulting many calls for data tuples, the __getitem__() method applied all requested transformation operations, including normalization, to the dataset’s raw data.
Note that the iteration is based on batches. The size of the batches was defined via the parameter “batch_size” when we instantiated the dataloader object.
Looking closely at the data we learn that the defined transformation operations for our basic dataset have indeed been performed:
After a standardization of the data, the mean-value of a batch and its standard deviation should be close to 0 and 1, respectively. This is obviously the case.
__getitem__() is called for every single element the dataloader fetches, i.e. batch_size times per batch. Therefore, during epoch-driven training cycles the defined transformation steps are repeated again and again, epoch for epoch, for each dataset element (a tuple). (In the example above we did not use the labels [y].)
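The repetition can be made visible with a hypothetical counting dataset (CountingDataset is an illustrative name). With the default num_workers=0, the dataset lives in the main process, so its call counter can be read afterwards. With worker processes, each worker would increment its own copy of the counter instead:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CountingDataset(Dataset):        # hypothetical helper, for illustration only
    def __init__(self, n):
        self.n = n
        self.calls = 0                 # counts __getitem__() invocations

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.calls += 1                # a real dataset would run its transforms here
        return torch.zeros(1, 28, 28), idx % 10

ds = CountingDataset(100)
loader = DataLoader(ds, batch_size=32, shuffle=True)   # num_workers=0 (default)
for epoch in range(3):
    for X, y in loader:
        pass
print(ds.calls)   # 300 = 3 epochs x 100 elements
```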
Cooperation of the dataloader with a NN-model
Let us assume that we have defined a NN-model, an optimizer (e.g. RMSprop) and a loss function (e.g. nn.NLLLoss). My simple test-model defined by a class “ConvNet()” can be described as follows:
from torchsummary import summary

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.cn1 = nn.Conv2d(1, 32, 3, 1, padding=1); self.cn2 = nn.Conv2d(32, 64, 3, 2, padding=1)
        self.cn3 = nn.Conv2d(64, 128, 3, 2, padding=1); self.cn4 = nn.Conv2d(128, 256, 3, 2)
        self.fl1 = nn.Flatten()
        self.fc1 = nn.Linear(2304, 64); self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.cn1(x); x = F.leaky_relu(x)
        x = self.cn2(x); x = F.leaky_relu(x)
        x = self.cn3(x); x = F.leaky_relu(x)
        x = self.cn4(x); x = F.leaky_relu(x)
        x = self.fl1(x); x = self.fc1(x); x = F.leaky_relu(x)
        x = self.fc2(x)   # Here no ReLU
        # Output via log_softmax
        mod_out = F.log_softmax(x, dim=1)
        return mod_out

model = ConvNet()
# Load the model to the cuda device
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
model.to(device)
# Print the summary of the model
# (torchsummary creates its test input on the device given here,
#  so model and input must be on the same device)
summary(model, input_size=(1, 28, 28), batch_size=-1, device=device.type)
# Loss and optimizer
# ~~~~~~~~~~~~~~~~~~~~
loss_fn = nn.NLLLoss()
optimizer = optim.RMSprop(model.parameters(), lr=LR)
--------------------------------------------------------
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 28, 28] 320
Conv2d-2 [-1, 64, 14, 14] 18,496
Conv2d-3 [-1, 128, 7, 7] 73,856
Conv2d-4 [-1, 256, 3, 3] 295,168
Flatten-5 [-1, 2304] 0
Linear-6 [-1, 64] 147,520
Linear-7 [-1, 10] 650
================================================================
Total params: 536,010
Trainable params: 536,010
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.37
Params size (MB): 2.04
Estimated Total Size (MB): 2.42
----------------------------------------------------------------
This model object was assigned to a variable “model“.
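The value 2304 for the input of the first fully connected layer follows from the Conv2d output-size formula out = floor((in + 2·padding − kernel) / stride) + 1 (with dilation 1). A short sketch that reproduces the spatial sizes in the summary above:

```python
def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d along one spatial dimension (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1

h = 28
sizes = []
for k, s, p in [(3, 1, 1), (3, 2, 1), (3, 2, 1), (3, 2, 0)]:   # cn1 ... cn4 of ConvNet
    h = conv_out(h, k, s, p)
    sizes.append(h)
print(sizes)           # [28, 14, 7, 3]
print(256 * h * h)     # 2304 = in_features of fc1
```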
Note: The model must be loaded to the Nvidia GPU for using respective CUDA acceleration. This can be achieved by using the “to“-method: “model.to(device)“. The “device” was set in a snippet above.
Loss function and optimizer: You may well change the loss function and the optimizer according to your own preferences for MNIST.
Training loop
The interaction between the dataloader and the NN-model can be defined via a function train_loop(). It controls the handling of a batch of data and, in our present case, also transfers the data to the GPU:
def train_loop(model, device, train_dataloader, optimizer, loss_fn, NP_Batches, epoch, b_preload=False):
    size = len(train_dataloader.dataset)
    print("epoch = ", epoch)
    # We iterate over all batches, each with batch_size elements; the last X may be shorter
    for batch_i, (X, y) in enumerate(train_dataloader):
        # Transfer the batch to the GPU, unless the data were preloaded there
        if not b_preload:
            if not b_ASYNC:
                X, y = X.to(device), y.to(device)
            else:
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
        # Standard steps for a batch
        # (Reset gradient, prediction, loss determination)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        loss.backward()
        optimizer.step()
        # Some output
        # The current position is given by all batches so far + len(X)
        if epoch % 5 == 0 and batch_i % NP_Batches == 0:
            loss, current = loss.item(), batch_i * BATCH_SIZE_TRAIN + len(X)
            print(f"loss: {loss:>7f} [{current:>5d} / {size:>5d}]")
Readers who have worked with a NN-model in PyTorch know this kind of steps, of course.
Note, by the way, that we will not transfer the X,y-data to the GPU again in experiments which preload all tensor data to the GPU ahead of the training. This will become important in forthcoming posts.
Performing a training run
The following statements initiate a training run over the defined number of epochs.
# Set the model into train mode
model.train()
# Do we work with preloaded data ?
if not b_TENSOR_DATASET:
    loader = train_dataloader
else:
    print("load via TENSOR loader")
    loader = train_tens_loader
# Perform the model training
start_time = time.perf_counter()
for epoch in range(NUM_EPOCHS):
    train_loop(model, device, loader, optimizer, loss_fn,
               NP_Batches, epoch, b_preload=b_PRELOAD)
end_time = time.perf_counter()
cpu_time = end_time - start_time
print()
print(" Calc time: ", cpu_time)
The somewhat convoluted conditions regarding the call of the right loader allow for a potential change to tensor datasets, respective dataloaders and a preloading of data to the GPU. As said, we will, however, neither use tensor datasets nor a preloading in this post.
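The run times in this post are taken with time.perf_counter() around the complete training. CUDA kernels are launched asynchronously, so timing shorter sections (e.g. a single epoch) is more reliable if the host waits for the GPU before reading the clock. A minimal sketch of such a helper is shown below. The name timed is hypothetical. For runs of several minutes, as in this post, the difference is negligible:

```python
import time
import torch

def timed(fn, *args, **kwargs):
    """Hypothetical helper: run fn and return (result, elapsed seconds).
    Synchronizing before and after makes sure that queued GPU work is
    included in the measured interval."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

# Usage with the objects above (not executed here):
# _, secs = timed(train_loop, model, device, loader, optimizer, loss_fn, NP_Batches, epoch)
result, secs = timed(torch.mm, torch.randn(256, 256), torch.randn(256, 256))
print(f"{secs:.4f} s")
```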
Training run with standard parameters of the dataloader and all transformation operations
During a training run I measure the average CPU load via Linux tools. I estimate the average GPU load by logging the output of “watch -n1.0 nvidia-smi” on a console terminal.
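Instead of reading averages off the nvidia-smi screen, one can let nvidia-smi print machine-readable values and average them. The sketch below assumes the query fields utilization.gpu and power.draw and the output format csv,noheader,nounits; check nvidia-smi --help-query-gpu on your system. The function names are hypothetical. On some GPUs power.draw is reported as “[N/A]”, which this simple parser does not handle:

```python
import subprocess
import time

QUERY = ["nvidia-smi", "--query-gpu=utilization.gpu,power.draw",
         "--format=csv,noheader,nounits"]

def parse_reading(line):
    """Parse one line like '13, 36.12' into (utilization in %, power in W)."""
    util, power = (float(v) for v in line.split(","))
    return util, power

def average_readings(lines):
    """Average utilization and power over a list of nvidia-smi output lines."""
    readings = [parse_reading(line) for line in lines if line.strip()]
    n = len(readings)
    return sum(u for u, _ in readings) / n, sum(p for _, p in readings) / n

def sample_gpu(seconds=60):
    """Poll nvidia-smi once per second (first GPU only) and return the averages."""
    lines = []
    for _ in range(seconds):
        out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
        lines.append(out.splitlines()[0])
        time.sleep(1.0)
    return average_readings(lines)

# Usage during a training run (e.g. in a separate terminal/process):
# avg_util, avg_watts = sample_gpu(60)
```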
Background activity: I keep a PA server for streaming music from a player active in the background, just to have some background load. Additionally, many open Firefox tabs create further background load. The real training run is done in a JupyterLab tab of the Firefox browser. Due to these background activities, the basic CPU load in my KDE environment (stretched over three screens, each with 2560×1440) varies between 2% and 9% ahead of a training run. The GPU load varies between 2% and 11%, depending on my interactions with application windows on the KDE desktop. These numbers mark the variation imposed on my measured run data. I have tried to measure averaged values over the full training period.
We start with a run for which we use standard parameters of the dataloader. All defined transformation operations, i.e. ToTensor() and Normalize(), are performed for each dataset element and repeated in each epoch. Note that this basically follows the prescriptions of the standard tutorial on PyTorch datasets.
The following table shows parameter settings and performance data for such a training run. Note that the batch size has a value of BATCH_SIZE_TRAIN = 32, which is not an unusual value.
| Parameter | Batch Size | Run time | Avg. CPU load | Avg. GPU load | Avg. GPU Watts |
|---|---|---|---|---|---|
| Standard, Shuffle | 32 | 468 secs | 17% | 13% | 36 |
A run over 40 epochs takes 468 secs – which is a disaster. The GPU and CPU obviously work far from their capacity limits.
epoch = 0
initial loss: 2.316442
....
epoch = 40
avg loss = 0.0005348916165530682
Calc time: 468.9557462849998
Model convergence is good, however. We started with an averaged loss of 2.35 to 2.5 over the very first batch of the first epoch. The final average loss of epoch 40 varies statistically between 0.0005 and 0.002. Of course, 40 epochs with a LR of 1.e-4 is much too much effort for MNIST, but this is not the topic here. Actually, we are running into massive overfitting after the first 15 epochs. But, we are interested in other things.
The central questions we must try to answer are:
- Which part of this lousy performance is due to the internal and the enforced transformation processes?
- Which part is due to the transfer of data batches to the GPU?
- Which part is due to the small model itself? Can we improve the overall performance by varying the batch size?
Changing parameters of the dataloader
Looking around on the Internet you may find the hint that you should set the parameter pin_memory to True. Let us try this. From now on we set our control parameter for using standard parameters to False and the parameter b_PIN_MEMORY to True:
b_STANDARD_PARAMS = False, b_PIN_MEMORY = True
Parameter settings are listed in the leftmost column of the following result table.
| Parameter | Batch Size | Run time | Avg. CPU load | Avg. GPU load | Avg. GPU Watts |
|---|---|---|---|---|---|
| Shuffle, pin_memory=True | 32 | 475 secs | 17% | 13% | 36 |
epoch = 0
initial loss: 2.316442
....
epoch = 40
avg loss = 0.0007310810033231974
Calc time: 475.5468245379998
Well, that did not help much on my system. One reason might be that our transfers to the GPU are still synchronous (b_ASYNC = False): according to the PyTorch documentation, pinned host memory is what allows a copy with non_blocking=True to run asynchronously with respect to the host. More important, however, is that this parameter only becomes really effective when multiple workers are active for data transformations and subsequent loading.
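What pinning means for a single batch can be sketched as follows. pin_memory() requires an available CUDA device, so the sketch only pins in that case. The random batch is a stand-in for what a DataLoader delivers. Per the PyTorch documentation of Tensor.to(), non_blocking=True only has an effect when the source tensor is in pinned memory:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X = torch.randn(128, 1, 28, 28)           # a CPU batch as delivered by a DataLoader

if device.type == "cuda":
    X = X.pin_memory()                     # copy into page-locked host memory
    assert X.is_pinned()

# With pinned memory the host-to-device copy can overlap with other host work;
# for pageable memory the non_blocking flag has no effect.
X_dev = X.to(device, non_blocking=True)
print(X_dev.device, X_dev.shape)
```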
So, let us in addition try different values of the dataloader’s parameter num_workers, which corresponds to the “number of worker processes” [NW]. This parameter enables a kind of parallelization: instead of one dataloader process we start multiple ones, which use available CPU cores in parallel. While this may improve the performance of the data transformations, the data transfer to the GPU may still become a bottleneck. Here, the physical architecture of mainboard, CPU and GPU as well as the number of physically available pipelines to the GPU may play a role. The RAM’s bandwidth and latency become important, too.
| Parameter | Batch Size | Run time | Avg. CPU load | Avg. GPU load | Avg. GPU Watts |
|---|---|---|---|---|---|
| Shuffle, pin_memory=True, num_workers=1, persistent_workers=True | 32 | 445 secs | 24% | 14% | 36 |
| Shuffle, pin_memory=True, num_workers=2 | 32 | 266 secs | 42% | 22% | 42 |
| Shuffle, pin_memory=True, num_workers=3 | 32 | 220 secs | 52% | 27% | 48 |
| Shuffle, pin_memory=True, num_workers=4, persistent_workers=True | 32 | 215 secs | 52% | 27% | 48 |
| Shuffle, pin_memory=True, num_workers=6, persistent_workers=True | 32 | 212 secs | 53% | 27% | 48 |
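Side remark: I have allowed for 7 threads in the startup script for JupyterLab (export OPENBLAS_NUM_THREADS=7; export OMP_NUM_THREADS=7). Note that num_workers=0 (the default) means that no separate worker processes are started: the data are then loaded in the main process. With num_workers > 0, the worker processes are by default shut down after each pass over the dataset, i.e. after each epoch, and started again for the next one. To avoid this overhead we also set the parameter “persistent_workers=True”, which requires num_workers > 0.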
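A sketch of a dataloader set up with these parameters on a stand-in TensorDataset follows. The sketch only constructs the loaders. When the worker processes are actually used in a script (not a notebook) on platforms that start processes via “spawn” (Windows, macOS), the iteration must be placed under an if __name__ == "__main__": guard. The second part shows that PyTorch rejects persistent_workers=True when no workers are requested:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

loader = DataLoader(ds, batch_size=128, shuffle=True,
                    num_workers=4,                       # 4 worker processes
                    persistent_workers=True,             # keep them alive between epochs
                    pin_memory=torch.cuda.is_available())

raised = False
try:
    DataLoader(ds, batch_size=128, num_workers=0, persistent_workers=True)
except ValueError as err:
    raised = True
    print("ValueError:", err)   # persistent workers need num_workers > 0
```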
With only one worker we find only a very small improvement. This had to be expected. But up to num_workers = 4, we see a major improvement. However, for more workers we find no further substantial gain. The rather small improvements for NW > 4 may even reside within the range of statistical fluctuations.
What could be reasons for this finding?
- Well, the i7-6700K CPU of my test system has just 4 physical cores; only hyperthreading provides 8 CPU threads. The hyperthreads may not come with much of an advantage in our situation.
- Another point is the number of tensors to be transferred. For a small batch size [BS] as BS=32 there are so many tensors to transfer to the GPU that the loader workers may have to wait for one another to finish. Data below for a larger BS actually seem to support this point of view.
- The model itself may be too small to efficiently challenge the GPU. In the end, i.e. for a sufficiently large number NW, batch for batch is just put through – without an option of further gains.
Anyway, further workers appear to be pretty useless under my present conditions – and the GPU load won’t go up. We obviously have to change other parameters in addition.
The impact of the Batch Size
Let us change the batch size (via our respective control parameter) to BATCH_SIZE_TRAIN=64 and further up to BATCH_SIZE_TRAIN=256. Note that with such a change we reduce the number of tensors which have to be transferred to the GPU, while the total amount of data per epoch stays the same.
| Parameter | Batch Size | Run time | Avg. CPU load | Avg. GPU load | Avg. GPU Watts |
|---|---|---|---|---|---|
| Shuffle, pin_memory=True, num_workers=1, persistent_workers=True | 64 | 362 secs | 13% | 20% | 36 |
| Shuffle, pin_memory=True, num_workers=3, persistent_workers=True | 64 | 160 secs | 55% | 28% | 50 |
| Shuffle, pin_memory=True, num_workers=4, persistent_workers=True | 64 | 129 secs | 73% | 34% | 53 |
| Shuffle, pin_memory=True, num_workers=6, persistent_workers=True | 64 | 120 secs | 74% | 36% | 57 |
| Shuffle, pin_memory=True, num_workers=1, persistent_workers=True | 128 | 334 secs | 12% | 19% | 39 |
| Shuffle, pin_memory=True, num_workers=4, persistent_workers=True | 128 | 116 secs | 64% | 33% | 53 |
| Shuffle, pin_memory=True, num_workers=6, persistent_workers=True | 128 | 87 secs | 90% | 44% | 72 |
| Shuffle, pin_memory=True, num_workers=1, persistent_workers=True | 256 | 312 secs | 13% | 19% | 39 |
| Shuffle, pin_memory=True, num_workers=4, persistent_workers=True | 256 | 107 secs | 62% | 34% | 62 |
| Shuffle, pin_memory=True, num_workers=6, persistent_workers=True | 256 | 81 secs | 88% | 45% | 76 |
In this case I used the best value of three runs for each parameter set. We see that the batch size plays an important role. By making the batch size larger, we raise the GPU load significantly. This is to be expected, of course. But this does not tell us what we can gain by reducing transformation processes and the transfer of batches to the GPU.
Summary of data
The following graphic summarizes some of the runtime data and their dependence on the number of workers [NW] and the batch size [BS].
[Figure: run times for different numbers of workers NW, grouped by batch size BS = 32, 64, 128, 256]

The data groups from left to right belong to BS=32, 64, 128, 256, respectively. Note the plateau for a small batch size of 32 and a number of workers NW ≥ 4. The plateau is, however, not as pronounced for larger batch sizes. So, for larger batch sizes BS ≥ 64, additional workers NW ≥ 6 actually do help!
In my case however, with 6 workers I reach a CPU load already close to 90%. So, I stopped adding further workers at this point.
Note that for a single worker (NW=1) the effect of larger batch sizes clearly diminishes with each doubling of the BS (445, 362, 334 and 312 secs for BS = 32, 64, 128 and 256). This is a sign that the provision of data to the GPU by only one worker happens much too slowly.
We see that we get close to optimal results for 6 workers (NW=6) and a batch size of BS=128. Going to a larger batch size does not come with a substantial improvement; we gain less than 10% in runtime reduction (87 secs vs. 81 secs). I have seen similar effects for TF2 pipelines.
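The diminishing returns can be quantified directly from the measured run times in the tables above. A small sketch (the helper name doubling_gains is hypothetical) computes the relative run-time reduction for each doubling of the batch size, for NW=1 and NW=6:

```python
# Measured run times [secs] from the tables above, keyed by batch size
nw1 = {32: 445, 64: 362, 128: 334, 256: 312}
nw6 = {32: 212, 64: 120, 128: 87, 256: 81}

def doubling_gains(runtimes):
    """Relative run-time reduction for each doubling of the batch size."""
    sizes = sorted(runtimes)
    return {f"{a}->{b}": (runtimes[a] - runtimes[b]) / runtimes[a]
            for a, b in zip(sizes, sizes[1:])}

for name, rt in (("NW=1", nw1), ("NW=6", nw6)):
    print(name, {k: f"{v:.1%}" for k, v in doubling_gains(rt).items()})
```

For NW=6 the gain drops from about 43% (32 to 64) to below 7% (128 to 256).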
(By the way: BS=128 appeared to be an optimal value for the 4060 Ti with respect to performance gains vs. energy consumption in other image-related ML-trainings, too.)
Let us face it: we have approached a kind of limit, and my GPU still does not work anywhere near its capacity level for the MNIST case. Its average load got stuck at only 45%.
Discussion and Conclusion
The above results for small datasets, small batch sizes and small NN models indicate a real mess regarding overall performance. The GPU works far below its limits. We saw that the lousy performance was only partly due to the relatively small number of sequential operations coming with a limited model. The GPU also waits for a faster flow of input data.
In this post we saw at least that two parameters controlling a PyTorch dataloader object may help to improve the overall performance. In particular the number of workers and the batch size were decisive.
- We could improve the overall performance by raising the number of parallel dataloader workers.
- We could raise the GPU load and at the same time improve the turnaround performance by raising the batch size.
The first point indicates that we can get some performance by parallelizing CPU-related transformation operations, thus increasing the flow of batches to the GPU. But, obviously, a parallelization effort regarding the CPU can reach a limit rather quickly if the CPU does not have enough cores and if its core performance is not sufficient. This may happen on older PC systems such as mine.
The second point describes a simple standard effect: the GPU uses more internal units to perform model operations in parallel. Whether fewer, but larger batches improve the data transfer to the GPU is, however, not clear, at all.
All in all we got down from 468 secs to almost 80 secs for the MNIST example, a simple CNN and 40 epochs by using both effects. This is already a factor of 5.9. However, we still experience a big gap between CPU/RAM and GPU capabilities.
We still have to investigate how big the impact of the pure transfer of our eventual tensors to the GPU is. So far, we have only seen the combined effects of CPU-processing and tensor transfer. In the next post we will, therefore, try out further adjustments:
We will avoid a repetition of transform operations by performing them once ahead of the training instead of epoch for epoch. We will furthermore look at the impact of a “non-blocking” option, which allows the transfer to the GPU to overlap with model calculations. We will try to find out whether the remaining shuffling and the pure tensor transfer to the GPU, organized by a Torch dataloader, are among the causes of our problem.
The results will show that working with prepared and preloaded tensor data gives us another acceleration factor of 2 to 2.4 (on a 4060 Ti). I.e. we will come down to 36 secs instead of 468 secs for our special test case and model. This result actually marks a kind of competitive value for other frameworks like Keras with a Tensorflow backend. But such an approach may not help in cases where statistical image augmentation is required during training. This raises a bunch of further questions.
Stay tuned …