
Performance of PyTorch vs. Keras 3 with tensorflow/torch backends for a small NN-model on a Nvidia 4060 TI – I – Torch vs. Keras3/TF2 and relevant parameters

Today’s world of Machine Learning is characterized by competing frameworks. I am used to the combination of Keras with the Tensorflow2 [TF2] backend, but have now turned to using PyTorch in addition. As a beginner with PyTorch, I wanted to get an impression of potential performance advantages in comparison with the Keras/TF2 framework combination. I had read about significant performance differences on the Internet.

In addition, Keras3 has supported Torch as a backend for more than 1.5 years now. The resulting range of options appeared worth some experiments on my consumer graphics card – an Nvidia 4060 TI.

Just to start with something, I made some performance tests with a rather simple neural network [NN] model. The model is of course not at all representative of the state of the art in Machine Learning. However, even small models can already reveal the impact of some parameter adjustments. I present the results of comparative runs in two posts. My results and some preliminary conclusions may at least be interesting for PyTorch beginners or for Torch users who want to give Keras3/TF2 a chance. Test runs with really deep neural networks will follow later this year.

In this first post, I give you data for a pure PyTorch approach on the one side vs. data of a standard Keras/Tensorflow2 [TF2] combination on the other side. In a forthcoming second post, I will then turn to Keras with the Torch backend.

In case you are just interested in the capabilities of the Nvidia 4060 TI, take the results as performance tests specific to small NN-models and to the efficient transfer of data to the GPU during training runs.

Note: Unfortunately, so far, I was not able to perform test runs with PyTorch and xla-compilation. The present version 2.6 of the required module “torch-xla” does not support CUDA GPUs. But even with the preceding version 2.5 of torch-xla (and torch), xla-compilation for CUDA GPUs appeared to produce discouragingly bad results in some first trials – at least on an Nvidia 4060. If I find the time, I will write about this topic in a separate post.

Test setup with small CNN model

I tested the training of a simple CNN with just 4 Conv2D layers and a small 2-layer FCN for the MNIST data. To inspect details of the NN-model see here. The model had around 563,000 parameters. So, it really is a very small model.
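For readers who want something concrete to run the PyTorch snippets against, here is a hypothetical CNN of the same general shape: 4 Conv2D layers followed by a 2-layer fully connected head for 1x28x28 MNIST images. It is an illustration under assumptions, not the model used in the runs: channel counts and hidden size are made up, and its parameter count (about 196,000) is well below the roughly 563,000 parameters of the actual model.

```python
import torch
from torch import nn


class SmallCNN(nn.Module):
    """Hypothetical 4x Conv2D + 2-layer FCN classifier for 1x28x28 MNIST images."""

    def __init__(self, n_classes: int = 10, hidden: int = 100):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(),    # 28 -> 26
            nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(),   # 26 -> 24
            nn.MaxPool2d(2),                               # 24 -> 12
            nn.Conv2d(64, 64, kernel_size=3), nn.ReLU(),   # 12 -> 10
            nn.Conv2d(64, 64, kernel_size=3), nn.ReLU(),   # 10 -> 8
            nn.MaxPool2d(2),                               # 8 -> 4
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, hidden), nn.ReLU(),
            nn.Linear(hidden, n_classes),  # raw logits, as expected by nn.CrossEntropyLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = SmallCNN()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 196182
```

Widening the first dense layer or padding the convolutions would bring the count closer to the model of the post; for the timing questions discussed here, only the rough size class matters.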

Nevertheless, with a batch size [BS] of BS=256, one can create a relatively high GPU load between 80% and 100% on the 4060 TI. The data transfer to the GPU becomes a key factor for the overall performance in such cases. Specialized tensor datasets of a framework may be helpful in this regard, and the optimal batch size has to be determined as well.

Each training run covered 41 epochs – each over 60,000 MNIST images. Regarding the handling of the training data, I used both plain Torch or TF2 tensors and dataset technology:

  • In the case of a pure Keras/TF2-approach, I used datasets created by the tf.data ()-functionality, more precisely by the function “tf.data.Dataset.from_tensor_slices ()”, which is the standard way to build TF2 data pipelines on top of Numpy arrays or prepared TF2-tensors.
  • When testing pure PyTorch and a Keras3/torch-combination, I employed torch standard datasets and also torch tensor-datasets. Note that for Torch you must refer to a respective dataloader in the training loop. For Keras3/TF2 you must pass the dataset to the “model.fit ()”-function.

For standard PyTorch datasets, I also used 6 worker processes for data manipulation, i.e. to prepare tensors suited for the input layer of the NN-model. The transformations applied to the data in this case comprised only (1) an obligatory reformatting to CxHxW torch tensors followed by (2) a standardization of the tensor values. Respective results are given below just for an extended comparison. However, from a previous post we already know that fully prepared torch tensors are handled more efficiently by a tensor-dataset and a respective dataloader (with standard parameters and num_workers = 0). So, tensor-datasets on the torch side were used for the core comparison with Keras/TF2 runs.

Prepared tensors or Numpy arrays: When I used torch tensor-datasets/dataloaders or TF2 datasets (via tf.data ()-sets), I did all normalization and transformation operations only once and ahead of forming the dataset. The only task that remained then was to load the data batch-wise to the GPU – either by respective dataloaders in PyTorch or by internal mechanisms of TF2’s model.fit-function.

Shuffling was always activated: for PyTorch via a parameter of the torch dataloaders; for TF2 either via an option in the definition of a tf.data-dataset or, in cases when I delivered Numpy arrays to the model.fit-function, by a respective parameter of “model.fit ()”.

My first runs were done without any special settings – for example without jit-compilation and/or without mixed-precision. A second and third bunch of runs then covered the effects of these options.

Below, I only present data for a batch-size BS=256. All frameworks and the Keras3/Torch combination gave me an optimal performance for this batch-size. All runs were done with RMSprop as the optimizer, a Learning Rate of LR=1.e-4 and CrossEntropyLoss as the loss function.

SW-versions / CUDA:
Python 3.11, Torch 2.6.0, Keras: 3.9.2, Nvidia driver 570.133.07.
Keras/TF2: CUDA 12.8.93, cudnn 9.8. PyTorch or Keras/torch: CUDA 12.4, cudnn: 9.1.

The differences in the CUDA/cudnn versions used stem from a global system-wide CUDA installation vs. a local installation enforced by the torch package in my virtual Python environment; see here for more information. CUDA capabilities of the graphics card were used by both frameworks. However, only Keras3/TF2 used additional xla-compilation on the CUDA device; see a respective section below.

Note 1: I did not yet test customized training loops in my Keras3/TF2 runs. I only tested with the standard model.fit ()-interface of Keras3.

Note 2: When you want to test Keras/TF2 with or without jit-compilation yourself, make sure that no environment variables of your (bash) shell activate jit-compilation for XLA globally. Statements in your shell configuration like

export TF_XLA_FLAGS=--tf_xla_cpu_global_jit
export TF_XLA_FLAGS=--tf_xla_auto_jit=1

take precedence over the option “jit_compile=False” in the statement model.compile(…).
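If you cannot or do not want to change your shell configuration, one option is to drop the variable inside the Python process before TensorFlow is loaded. The sketch below assumes that TensorFlow evaluates TF_XLA_FLAGS when it is first imported and initialized, which is why the order matters; it is an illustration, not a recipe verified in the runs above. In the shell itself, running `unset TF_XLA_FLAGS` before starting Python has the same effect.

```python
import os

# Drop a globally exported XLA auto-jit hint for *this* Python process only.
# Assumption: TensorFlow reads TF_XLA_FLAGS when it initializes, so this
# must happen before TensorFlow/Keras is imported.
removed = os.environ.pop("TF_XLA_FLAGS", None)
if removed is not None:
    print(f"Ignoring TF_XLA_FLAGS={removed!r} for this run")

# Keras3 with the TF2 backend (also has to be set before importing keras)
os.environ["KERAS_BACKEND"] = "tensorflow"

# import keras  # import only now; jit_compile in model.compile() then decides alone
```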

Abbreviations used in tables below:

  • DS: Dataset,
  • DL: Dataloader,
  • NW (Torch DS, TF DS): Number of workers or parallel processes,
  • SPE (Keras): Keras compile option steps_per_execution. Note: if SPE is chosen to be bigger than the batch size, it will automatically be reduced.
  • EC: Energy consumption of the graphics card per second, i.e. its power draw in Watt,
  • BS: Batch Size.
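For orientation regarding SPE, it helps to know how many batches (Keras "steps") one epoch contains at BS=256. Whether the automatic reduction of SPE is really tied to the batch size or rather to the number of steps per epoch is not checked here; the helper effective_spe() below is hypothetical and simply assumes the latter. At BS=256 the two readings are numerically close anyway (256 vs. 235).

```python
import math

N_TRAIN = 60_000   # MNIST training images per epoch
BATCH_SIZE = 256

# number of batches ("steps") per epoch, including the last, partial batch
steps_per_epoch = math.ceil(N_TRAIN / BATCH_SIZE)
print(steps_per_epoch)  # 235


def effective_spe(requested: int, steps: int = steps_per_epoch) -> int:
    """Hypothetical: SPE capped at the number of steps in one epoch (assumption)."""
    return min(requested, steps)


print(effective_spe(512))  # 235
print(effective_spe(1))    # 1
```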

Hardware: i7-6700K, Nvidia 4060 TI [16GB VRAM]

Pure PyTorch reference runs – without jit-compilation and without mixed-precision

The input layer(s) of my PyTorch based NN-model defined the format which the torch input tensors had to fulfill. PyTorch’s standard datasets for test cases like MNIST do not always provide data fitting these requirements. In addition, you may want to normalize and shuffle the data during the training of your NN. Then you have two alternatives:

  1. Worker processes preparing data on the fly: You can perform the necessary modifications within a PyTorch dataset’s functionality. You can invoke a number NW of parallel worker processes (parameter num_workers of the torch dataloader) which perform the required data modifications and deliver data batches that are afterward sent to the NN-model on the GPU.
  2. Tensor-datasets built with fully prepared tensors: You can apply the necessary modifications once via your own loops and ahead of building a tensor-dataset. A dataloader then only handles batches of fully prepared tensors – and just transfers them to the NN-model on the GPU. With this option, it is also possible to preload the tensors to the GPU. The tensor-dataset is then built with tensors already residing in the GPU’s VRAM – and shuffling reduces to playing with indices.
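Option 2, including the preloading to the GPU, can be sketched as follows. This is a minimal illustration that assumes random stand-in data instead of the real MNIST arrays; the standardization uses the statistics of the data at hand.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the raw MNIST training data: uint8 images (N, 28, 28) and int64 labels (N,)
raw_images = torch.randint(0, 256, (1024, 28, 28), dtype=torch.uint8)
raw_labels = torch.randint(0, 10, (1024,), dtype=torch.int64)

# One-time preparation, done before the dataset is built:
# (1) float32 values in [0, 1] with an explicit channel axis -> N x C x H x W
x = raw_images.to(torch.float32).div(255.0).unsqueeze(1)
# (2) standardization with mean and std of the training data
x = (x - x.mean()) / x.std()

# Preloading: move the prepared tensors into VRAM once
x, y = x.to(device), raw_labels.to(device)

train_ds = TensorDataset(x, y)
# num_workers=0: no worker processes; the main process only draws shuffled
# indices and stacks the already prepared samples into batches
train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)

xb, yb = next(iter(train_dl))
print(xb.shape)  # torch.Size([256, 1, 28, 28])
```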

The following runs discriminate between these options. The runs have been performed without jit-compilation and without mixed-precision. Optimizer (RMSprop) and loss (CrossEntropyLoss) were taken from respective torch modules. A standard torch training loop was used (with a forward and a backward pass for the training). The batch size for the dataloader was BS=256. In the case of tensor-datasets, I also preloaded the tensors to the GPU, which gave a small improvement.
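A standard training loop of this kind, with RMSprop at LR=1e-4, CrossEntropyLoss and BS=256 as in the runs, might look like the sketch below. The tiny model and the random stand-in tensors are placeholders, not the setup used for the measurements. One design detail: the epoch loss is accumulated on the device and converted with .item() only once per epoch, which avoids forcing a GPU-to-CPU synchronization after every batch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model (hypothetical) producing logits for 10 classes
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2),  # 28 -> 26 -> 13
    nn.Flatten(), nn.Linear(16 * 13 * 13, 10),
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)

# Placeholder for prepared (normalized, N x 1 x 28 x 28) tensors, preloaded to the device
x = torch.randn(1024, 1, 28, 28, device=device)
y = torch.randint(0, 10, (1024,), device=device)
train_dl = DataLoader(TensorDataset(x, y), batch_size=256, shuffle=True, num_workers=0)


def train_one_epoch(model, dl):
    model.train()
    loss_sum = torch.zeros((), device=device)
    for xb, yb in dl:
        # no-op for preloaded tensors; copies CPU batches to the GPU otherwise
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(xb), yb)  # forward pass (mean loss of the batch)
        loss.backward()                # backward pass
        optimizer.step()               # parameter update
        loss_sum += loss.detach() * xb.size(0)
    return (loss_sum / len(dl.dataset)).item()  # loss averaged over all items of the epoch


for epoch in range(2):  # the runs in the post used 41 epochs
    avg_loss = train_one_epoch(model, train_dl)
    print(f"epoch {epoch + 1}: loss {avg_loss:.4f}")
```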

Runs PyTorch / NN-model defined by Torch / MNIST dataset vs. tensor-dataset of prepared tensors / NO jit-compilation / NO mixed-precision / BS=256

| Config | Input data format | NW | SPE | Time async [sec] | Time preload [sec] | GPU load [%] | EC GPU [Watt] | Remarks |
|---|---|---|---|---|---|---|---|---|
| Pure PyTorch, no jit-compilation, model/loss/optimizer by Torch (with std. loop) | Torch MNIST dataset with transformations (sync) | 0 | | 322 | | 12 | 36 | Totally inefficient, main process for dataloading does all transformations |
| Pure PyTorch, no jit-compilation, model/loss/optimizer by Torch (with std. loop) | Torch MNIST DS with transformations (async) | 6 | | 80 | | 45 | 76 | Parallelized processes for data transformations, CPU load around 86% |
| Pure PyTorch, no jit-compilation, model/loss/optimizer by Torch (with std. loop) | Torch tensor-DS over prepared tensors, standard params | 0 | | 38 | 35 | 99 | 138 | CPU load around 18%, high GPU load => efficient transfer and handling of data on GPU |

Convergence of the loss was very good. The loss, averaged over all items of an epoch, went down from 2.4 to around and below 0.002. Note that I used full precision (float32 for the data and int64 for labels) during the above runs.

Note also the high GPU load (99%) for the run with a tensor-dataset.

Intermediate conclusion

  • Good PyTorch performance without any special settings:
    A pure PyTorch approach with a tensor-dataset of prepared tensors obviously offers a very efficient way to use the GPU optimally for (small) NN-models.

The following runs will show that the performance was really good.

Pure Keras/TF2 runs – without jit-compilation and without mixed-precision

Test-runs with Keras3/TF2 were done

  • either with providing Numpy arrays or (prepared) tensors to the model.fit-function
  • or by using tf.data datasets.

When I used tf.data datasets, I set the tunable parameters, namely (1) num_parallel_calls for batch production and (2) buffer_size for prefetching, to tf.data.AUTOTUNE. In addition, the runs with tf.data-datasets were done with a shuffle buffer size (shuffling width) of 20,000.
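A tf.data pipeline with these settings could look like the following sketch. It assumes prepared float32 arrays of shape N x 28 x 28 x 1 (Keras' channels-last default) and int64 labels; random stand-in data replaces MNIST. Because the dataset already yields batches, it is passed to model.fit() without a batch_size argument.

```python
import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 256
SHUFFLE_BUFFER = 20_000  # shuffling width used in the runs


def make_train_ds(x: np.ndarray, y: np.ndarray) -> tf.data.Dataset:
    ds = tf.data.Dataset.from_tensor_slices((x, y))                 # Numpy arrays or TF tensors
    ds = ds.shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)  # new order in every epoch
    ds = ds.batch(BATCH_SIZE, num_parallel_calls=AUTOTUNE)          # (1) parallel batch production
    return ds.prefetch(buffer_size=AUTOTUNE)                        # (2) overlap input and training


# Stand-in for prepared (normalized) MNIST data
x_train = np.random.rand(1000, 28, 28, 1).astype("float32")
y_train = np.random.randint(0, 10, size=(1000,)).astype("int64")
train_ds = make_train_ds(x_train, y_train)

# usage with a compiled Keras model: model.fit(train_ds, epochs=41)
```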

Runs Keras3/TF2 / NN-model defined by Keras3 / MNIST data handled as prepared Numpy arrays and TF2-tensors / use of tf.data-datasets / NO jit-compilation / NO mixed-precision / BS=256

| Config | Input data format | NW | SPE | Time async [sec] | Time preload [sec] | GPU load [%] | EC GPU [Watt] | Remarks |
|---|---|---|---|---|---|---|---|---|
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | TF tensors, tf.data-dataset, prefetching and num_parallel_calls autotuned | autotune | 1 | 66 | 67 | 68-75 | 104-108 | CPU load around 22% |
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | Numpy arrays, tf.data-dataset, prefetching and num_parallel_calls autotuned | autotune | 1 | 67 | | 69-75 | 103-107 | CPU load around 22% |
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | Numpy arrays, no tf.data-dataset, array transfer/shuffling via model.fit() | | 1 | 65.5 | | 68-75 | 102-108 | CPU load varies 18-24% |
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | TF tensors, tf.data-dataset, prefetching and num_parallel_calls autotuned | autotune | 512 | 57 | 57 | 78-85 | 111-115 | CPU load around 20% |
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | Numpy arrays, tf.data-dataset, prefetching and num_parallel_calls autotuned | autotune | 512 | 57 | | 75-85 | 111-115 | CPU load around 22% |
| Keras3/TF2, jit-compile=False, model, loss, optimizer via Keras | Numpy arrays, no dataset, arrays and shuffling via model.fit() | | 512 | 58 | | 75-85 | 110-116 | CPU load around 19% |

Convergence of the loss was slightly worse than with torch. The loss went down from 2.4 to around and below 0.005.

Because the Numpy arrays were already adjusted and normalized, the above Keras3/TF2 runs must be compared with PyTorch runs based on prepared tensors. The datasets were based on normalized TF-tensors or respective Numpy arrays. The CPU load varied around the given mean values in a relatively large interval of at least 6% to 8%.

Intermediate conclusions – all of course for the given situation without jit-compilation:

  1. Minor differences for providing Numpy arrays or tensors to “model.fit ()”:
    The differences between using Numpy arrays and using (prepared) TF-tensors are minor, if not insignificant.
  2. Minor differences for using tf.data datasets:
    Without jit-compilation, using a “tf.data” dataset makes a relatively small difference vs. providing Numpy arrays or tensors to “model.fit ()”.
  3. No performance improvement for tensors preloaded to the GPU:
    Without jit-compilation, preloading TF-tensors to the GPU does not make things better, but slightly worse. Data not shown.
  4. Keras3/TF2 without jit slower than pure PyTorch:
    Without jit-compilation, Keras3/TF2 is slower than a pure PyTorch approach by factors between 1.5 and 1.9.
  5. The parameter steps_per_execution [SPE] of “model.compile ()” is of major importance!
    You need to provide a sufficiently large value SPE ≥ BS to improve the overall performance. (If you go beyond the batch size, Keras will reduce it automatically.)

Note: How far you can and should raise the steps_per_execution parameter may depend on physical capabilities of your system. You have to play around to find an optimal value.
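In code, steps_per_execution (and jit_compile, discussed further below) are plain arguments of model.compile(). The sketch assumes the Keras3 API with the TensorFlow backend; the small model is hypothetical, and SparseCategoricalCrossentropy(from_logits=True) is used here as the Keras counterpart of torch's CrossEntropyLoss for integer labels.

```python
import os
os.environ.setdefault("KERAS_BACKEND", "tensorflow")  # must be set before importing keras

import numpy as np
import keras
from keras import layers

# Hypothetical small CNN (channels-last input) returning logits
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(2),
    layers.Flatten(),
    layers.Dense(10),
])

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
    steps_per_execution=512,  # many batches per call of the compiled train function
    jit_compile=False,        # setting of the runs in this section
)

# quick shape check with a dummy batch
logits = model(np.zeros((2, 28, 28, 1), dtype="float32"))
print(tuple(logits.shape))  # (2, 10)

# usage with prepared Numpy arrays; shuffling is done by model.fit():
# model.fit(x_train, y_train, batch_size=256, epochs=41, shuffle=True)
```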

Pure Keras3/TF2 and PyTorch runs with activated jit-compilation

Let us see how big the impact of jit-compilation is for both frameworks. Shuffling for tf.data-datasets was again set to 20,000 units. For Keras3 you find information on using jit-compile here; for PyTorch see here and here:

  • Note that for Keras we activate jit-compilation by setting a parameter “jit_compile” of the model.compile-function either to True or ‘auto’.
  • For PyTorch we get jit-compilation via using the “torch.compile ()”-function. A “mode” of compilation can be set for the default “inductor” backend.

Regarding compilation for PyTorch, I should add the following remarks:

  • For the numbers given below, I applied compilation only to the NN-model, i.e.:
model_train = torch.compile (model, mode="reduce-overhead")
  • Other patterns have been proposed. See here. I tried such an extended approach (without xla), too, but the difference to just compiling the model was marginal. Respective numbers are not shown below.
  • I also tried a different compilation backend, namely backend='cudagraphs'. (No mode parameter can be set then. The last batch has to be dropped to guarantee equal batch sizes.) Unfortunately, “cudagraphs” gave a worse performance than the standard “inductor”-backend. Respective numbers are not shown below.
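The model-only compilation described in the first bullet can be sketched like this. The model is a hypothetical stand-in, and the cudagraphs variant is only shown as a comment; note that torch.compile() works lazily, so the actual compilation happens during the first forward call with real data.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the NN-model
model = nn.Sequential(
    nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10),
).to(device)

# Compile only the model (default backend "inductor"); loss and optimizer stay uncompiled
model_train = torch.compile(model, mode="reduce-overhead")

# The compiled wrapper shares its parameters with the original model,
# so the optimizer can be built from either of them.
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)

# Alternative backend tried in the post (no `mode` argument possible):
# model_cg = torch.compile(model, backend="cudagraphs")
# ...together with DataLoader(..., drop_last=True), so that every batch has the same shape.
```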

Important notes regarding XLA capabilities:

  1. Keras/TF2 automatically detects xla-capabilities of the GPU and uses them during jit-compilation.
  2. PyTorch does not use xla-capabilities automatically. Instead, PyTorch requires the extra modules torch-xla (version 2.5) and torch-xla-cuda-plugin (version 2.5). The latest version 2.6 of torch-xla does not support CUDA GPUs at all. However, even for version 2.5, first test results were frustrating. So, I did not include XLA based results for PyTorch runs below.

Pure Keras3/TF2 and pure PyTorch runs with jit-compilation

| Config | Input data format | NW | SPE | Time async [sec] | Time preload [sec] | GPU load [%] | EC GPU [Watt] | Remarks |
|---|---|---|---|---|---|---|---|---|
| Keras3/TF2, jit-compile=True, model, loss, optimizer by Keras | Numpy arrays, no dataset, arrays and shuffling via model.fit() | | 512 | 43.7 | | 78-84 | 111-117 | CPU load around 19%, XLA was used |
| Keras3/TF2, jit-compile=auto, model, loss, optimizer by Keras | TF-tensors, no dataset, tensors and shuffling via model.fit() | | 512 | 43.5 | 43.2 | 78-84 | 107-116 | CPU load around 17%-18%, XLA was used |
| Keras3/TF2, jit-compile=auto, model, loss, optimizer by Keras | Numpy in TF2-dataset (tf.data) | autotune for parallel calls, autotune prefetch | 1 | 50.6 | 51.1 | 65-72 | 102-104 | CPU load around 20%, XLA was used |
| Keras3/TF2, jit-compile=auto, model, loss, optimizer by Keras | Numpy in TF2-dataset (tf.data) | with or without autotune for parallel calls, autotune prefetch | 512 | 41.0 | 41.3 | 79-85 | 112-122 | CPU load around 19%, XLA was used |
| Keras3/TF2, jit-compile=auto, model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 1 | 51.2 | 51.1 | 65-74 | 97-103 | CPU load around 19%, XLA was used |
| Keras3/TF2, jit-compile=auto, model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 512 | 41.0 | 41.3 | 66-84 | 112-122 | CPU load around 19%, XLA was used |
| Pure PyTorch, jit-compilation with mode="reduce-overhead", model/loss/optimizer by torch with loop | Unprepared torch tensors, torch dataset with workers (async) | 6 | | 82 | | 37-44 | 74 | 84% CPU load – no improvement for tensors without tensor-dataset; turnaround time even worse than without jit! |
| Pure PyTorch, jit-compile without setting the mode, model/loss/optimizer by torch | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 33 | 33 | 92-100 | 137-139 | Some variation of GPU load |
| Pure PyTorch, jit-compile with mode="reduce-overhead", model/loss/optimizer by torch, torch.set_float32_matmul_precision('high') | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 31.4 | 30.0 | 99-100 | 140-150 | CPU load around 16%-17%; highest GPU load and EC up to 150 W for preloaded tensors; GPU temp > 72°C |
| Pure PyTorch, jit-compile with mode="max-autotune", model/loss/optimizer by torch, torch.set_float32_matmul_precision('high') | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 31.3 | 30.0 | 99-100 | 140-150 | CPU load around 16%-17%; highest GPU load and EC up to 150 W for preloaded tensors; GPU temp > 72°C |

Variations around 0.1-0.2 secs should not be taken too seriously. They may depend on some background load of my system. I averaged over three runs.

Intermediate conclusions:

  1. jit-compilation gives Keras3/TF2 a boost:
    jit-compilation gives Keras3/TF2 a real boost of up to a factor of 1.4, in particular for a high enough SPE-value.
  2. Keras3/TF2 – steps_per_execution important:
    The parameter steps_per_execution [SPE] of “model.compile ()” remains important for optimal performance – also when using jit-compilation.
  3. Keras3/TF2 with jit-compilation is slower than pure PyTorch – with and without jit:
    Even with jit and optimal SPE-values, the turnaround times of a training with Keras3/TF2 never reach the values of a pure PyTorch approach. Keras3/TF2 with jit is still slower than PyTorch without jit-compilation by a factor of 1.17.
    Keras3/TF2 with jit is slower than PyTorch with jit-compilation by a factor of up to 1.36.
    For an old fan of Keras and TF2 this finding is a bit frustrating. PyTorch seems to use the GPU more efficiently than the Keras3/TF2-combination.
  4. PyTorch – minor improvements by jit-compilation:
    jit-compilation improves the PyTorch performance “only” by up to a factor of 1.16 on an Nvidia 4060 TI.
  5. PyTorch – other unconventional settings may help:
    Following hints given in PyTorch's warnings, these settings were helpful for optimal performance:
    torch.set_float32_matmul_precision('high')
    torch._dynamo.config.cache_size_limit = 64
    The first setting enables the usage of the GPU's tensor cores (TF32) for float32 matrix multiplications. The second may help to avoid some warnings about recompilations. Note that its value may depend both on the model and on the dimensions of the torch tensors.
  6. PyTorch – preloading tensors of tensor-datasets to the GPU may help a bit:
    Preloading the prepared torch tensors to the GPU before building the tensor-dataset improved performance by a few percent.
  7. PyTorch – no improvements by jit for standard datasets with workers:
    jit-compilation does not help with unprepared tensors and a standard torch dataset with worker processes.
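Put together at the top of a training script, the settings of item 5 look as follows. The explicit cuDNN line is an addition for clarity: as far as I know, Conv2D layers are governed by this separate TF32 switch (on by default), not by the matmul precision setting. The attribute name cache_size_limit is the one used in the post; newer PyTorch releases may name this limit differently.

```python
import torch
import torch._dynamo

# Allow TF32 tensor-core math for float32 matrix multiplications
# ("highest" = strict float32 (default), "high" = TF32 permitted).
torch.set_float32_matmul_precision("high")

# cuDNN convolutions have their own TF32 switch (True by default).
torch.backends.cudnn.allow_tf32 = True

# Allow more recompilations per compiled function before dynamo
# gives up and falls back to eager execution for it.
torch._dynamo.config.cache_size_limit = 64
```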

Pure Keras3/TF2 and pure PyTorch runs with activated mixed precision

From working with LLMs and StableDiffusion, we know very well that “mixed precision” can come with

  • a substantial improvement of performance
  • and at the same time a reduction of GPU load.

So, let us see how PyTorch and Keras3/TF2 react to mixed precision. For Keras3 just one line of code is required; see here. For PyTorch multiple elements of the training loop have to be changed; see here.
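For Keras3, the "one line" is the global dtype policy. The sketch below assumes the TensorFlow backend and a hypothetical two-layer model; keeping the output layer in float32 follows the usual Keras recommendation for numerically stable losses. As far as I know, Keras3's model.compile() then also wraps the optimizer in a LossScaleOptimizer automatically (argument auto_scale_loss=True), i.e. loss scaling comes for free on the Keras side.

```python
import os
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
from keras import layers

# The "one line": layers created from now on compute in float16 but keep float32 weights
keras.mixed_precision.set_global_policy("mixed_float16")

conv = layers.Conv2D(32, 3, activation="relu")
logits = layers.Dense(10, dtype="float32")  # keep the output layer in float32
model = keras.Sequential([keras.Input(shape=(28, 28, 1)), conv, layers.Flatten(), logits])

print(conv.compute_dtype, conv.variable_dtype)  # float16 float32
```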

Regarding PyTorch and mixed precision, we also have to consider two special points within the training loop:

  1. It is not recommended to apply reduced precision (autocast) to the steps correcting the gradients: loss.backward() and optimizer.step(). Only the forward pass, including the loss computation, should run under autocast.
  2. Furthermore, in case of instabilities (typically caused by small float16 gradients underflowing to zero), in particular for deep networks, we should use a scaler for gradients. This is done by creating an instance of “GradScaler” and using it during the backward pass.

import torch

scaler = torch.amp.GradScaler('cuda', enabled=True)
# ... forward pass and loss computation under torch.autocast(...) ...
# and for gradient adjustment in the training loop (outside autocast):
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()

I distinguish for PyTorch between using a Scaler and not using it. I indicate this by the phrase “WITH Scaler” or “NO Scaler”, respectively.

Note again: jit-compilation for PyTorch runs was done without using XLA!

Runs Keras3/TF2 and PyTorch with mixed precision / BS=256

| Config | Input data format | NW | SPE | Time async [sec] | Time preload [sec] | GPU load [%] | EC GPU [Watt] | Remarks |
|---|---|---|---|---|---|---|---|---|
| Keras3/TF2, NO jit-compilation, model, loss, optimizer by Keras | Numpy arrays, no dataset, array handling + shuffling via model.fit() | | 1 | 52.7 | | 38-54 | 64-70 | CPU load varying around 22% |
| Keras3/TF2, NO jit-compilation, model, loss, optimizer by Keras | Numpy arrays, no dataset, array handling + shuffling via model.fit() | | 512 | 41.8 | | 57-63 | 76-81 | CPU load varying around 22% |
| Keras3/TF2, jit-compile='auto', model, loss, optimizer by Keras | Numpy arrays, no dataset, array handling + shuffling via model.fit() | | 1 | 25.8 | | 46-84 | 46-90 | CPU load varying around 22%, XLA was used |
| Keras3/TF2, jit-compile='auto', model, loss, optimizer by Keras | Numpy arrays, no dataset, array handling + shuffling via model.fit() | | 512 | 18.0 | | 44-69 | 84-91 | CPU load varying around 22%, XLA was used |
| Keras3/TF2, NO jit-compilation, model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 1 | 51.1 | 51.3 | 48-55 | 67-71 | CPU load varying around 25% |
| Keras3/TF2, NO jit-compilation, model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 512 | 39.9 | 40.7 | 56-65 | 79-84 | CPU load varying around 24% |
| Keras3/TF2, jit-compile='auto', model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 1 | 24.7 | 24.9 | 37-49 | 70-73 | CPU load varying around 22%, XLA was used |
| Keras3/TF2, jit-compile='auto', model, loss, optimizer by Keras | Tensors in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 512 | 16.8 | 16.8 | 55-72 | 88-92 | CPU load varying around 27%, XLA was used |
| Keras3/TF2, jit-compile='auto', model, loss, optimizer by Keras | Numpy arrays in TF2 dataset (tf.data) | autotune prefetch, parallel calls | 512 | 16.8 | 16.8 | 54-71 | 89-92 | CPU load varying around 27%, XLA was used |
| Pure PyTorch, jit-compilation with mode="reduce-overhead", model/loss/optimizer by torch, WITH Scaler | Unprepared torch tensors, torch dataset with workers (async) | 6 | | 82 | | 23-28 | 47-48 | CPU load varying around 85%; turnaround time even worse than without jit and without mixed precision |
| Pure PyTorch, NO jit-compilation, model/loss/optimizer by torch, WITH Scaler | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 45.5 | 46.8 | 44-50 | 72-76 | CPU load varying around 16% |
| Pure PyTorch, jit-compilation with mode="reduce-overhead", model/loss/optimizer by torch, torch.set_float32_matmul_precision('high'), WITH Scaler | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 38.3 | 39.6 | 45-51 | 72-80 | CPU load varying around 16% |
| Pure PyTorch, NO jit-compilation, model/loss/optimizer by torch, torch.set_float32_matmul_precision('high'), NO Scaler | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 37.0 | 36.3 | 53-58 | 78-82 | CPU load varying around 16% to 17% |
| Pure PyTorch, jit-compilation with mode="reduce-overhead", model/loss/optimizer by torch, torch.set_float32_matmul_precision('high'), NO Scaler | Torch tensor-dataset on prepared tensors, standard params, async | 0 | | 27.0 | 26.9 | 64-70 | 88-95 | CPU load varying around 20%; note the reduced GPU load! |

Variations around 0.1-0.2 secs in measured run-times should not be taken too seriously. They may depend on some background load of my system. I averaged over three runs. So, there may still be a tendency.

Intermediate conclusions:

  1. No problems regarding loss convergence:
    In general, the loss convergence did not suffer from mixed precision. This is true both for PyTorch and Keras3/TF2.
  2. Substantial reduction of GPU load:
    Mixed precision comes with a substantial reduction of the average GPU load and the respective energy consumption. This is independent of Keras or PyTorch. However, for the fastest runs with Keras3/TF2 the GPU load is higher than for the fastest run with PyTorch. This is also reflected in turnaround times.
  3. Keras3/TF2 – performance improvements:
    Even without jit-compilation, the performance of Keras3/TF2 improves substantially with mixed precision (and otherwise equal parameters and settings). High values for steps_per_execution are crucial also for mixed-precision runs.
  4. Keras3/TF2 – top performance with jit-compilation and sufficiently high SPE values:
    A top turnaround time of 16.8 secs for 41 epochs (each for 60,000 images and with shuffling) was produced by Keras3/TF2. jit-compilation (with xla!), tf.data datasets with autotuning and mixed precision had to be combined.
  5. PyTorch – performance reduction regarding time for mixed precision with a Scaler:
    Mixed precision with the recommended use of a gradient Scaler may reduce performance in comparison to standard PyTorch runs without mixed precision.
  6. PyTorch – no performance gain regarding time for mixed precision without jit-compilation:
    Even without a Scaler, PyTorch runs with mixed precision do not give us a better turnaround time when no jit-compilation is done. However, the GPU load and the respective energy consumption are reduced in comparison with full-precision runs.
  7. PyTorch – only slight improvement of turnaround time for mixed precision with jit-compilation:
    With jit-compilation, we may experience a slight performance gain of roughly 10% to 14% in comparison to the corresponding full-precision runs with jit-compilation (26.9 vs. 30.0 secs with preloaded tensors, 27.0 vs. 31.4 secs async).
  8. PyTorch with mixed precision and jit-compilation consistently slower than all jit-compiled Keras3/TF2 runs:
    While PyTorch consistently beat Keras3/TF2 in full-precision runs, PyTorch (without XLA) cannot compete with Keras3/TF2 when jit-compilation is combined with datasets and mixed precision. The best jit-compiled (without xla!) PyTorch run is slower than the best Keras3/TF2 run (with xla-compilation) by a factor of 1.6. This holds although the energy consumption of PyTorch’s best run is roughly the same as that of the best run with Keras3/TF2.

Conclusion

The runs and performance differences discussed above revealed some aspects and important parameters for potential performance gains:

  • Datasets and respective loaders should be used with both frameworks to reach optimal performance. Respective parameters like “num_workers” (Torch) or “num_parallel_calls” (TF2) must be tuned according to a given situation. For PyTorch, tensor-datasets over prepared tensors give optimal performance; for tensor-datasets, num_workers should be set to 0.
  • Keras3/TF2: The parameter “steps_per_execution” of “model.compile ()” has a big impact on performance. It should get a high value around the batch size (if possible).
  • Keras3/TF2: Both jit-compilation and mixed precision have a big impact on performance. The parameter jit_compile should be set to “auto” or True. Mixed precision should be used whenever possible.
  • Keras3/TF2: autotune parameters help with tf.data datasets.
  • PyTorch: Tensor cores should be activated by using “torch.set_float32_matmul_precision('high')”.
  • PyTorch: jit-compilation should be tried together with setting the mode parameter to “reduce-overhead”.
  • PyTorch: Mixed precision should be tried out for the forward pass in the training loop – and, if possible, without using a Scaler for gradients in the backward pass. If mixed precision is used, then jit-compilation has a significant impact on performance.
  • PyTorch: Adjusting the parameter torch._dynamo.config.cache_size_limit (e.g. to 64 or 128) may help to circumvent certain warnings and improve performance a bit.

However, all in all the runtime data gathered for PyTorch and Keras3/TF2 sent mixed messages:

  • Standard precision (float32) – PyTorch better:
    A pure PyTorch approach produces very good results for standard precision. With $F_{P2K}$ denoting the ratio of the Keras3/TF2 turnaround time to the PyTorch turnaround time, PyTorch is faster by factors
    $1.17 \le F_{P2K} \le 1.36$,
    i.e. it performs better than Keras3/TF2 and uses the GPU more efficiently.
  • Mixed precision (float16 for the forward pass) and jit-compilation – Keras/TF2 much better:
    For mixed precision, a pure PyTorch approach cannot compete with Keras3/TF2. PyTorch with mixed precision and jit-compilation, but without a gradient Scaler, only reaches
    $F_{P2K} \le 0.63$,
    i.e. it is slower than Keras3/TF2 at approximately the same rate of energy consumption.
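The factors FP2K quoted above can be reproduced from the tables. The sketch assumes that FP2K is the ratio of the Keras3/TF2 turnaround time to the PyTorch turnaround time, so that values above 1 mean PyTorch is faster; this definition is inferred from the numbers. Up to rounding (1.37 here vs. the 1.36 quoted above) the results match.

```python
# Turnaround times [sec] for 41 epochs, taken from the tables above
t_keras_fp32 = 41.0        # Keras3/TF2, jit (XLA), tf.data, SPE=512, full precision
t_torch_fp32_nojit = 35.0  # PyTorch tensor-DS, preloaded, no jit
t_torch_fp32_jit = 30.0    # PyTorch tensor-DS, preloaded, torch.compile
t_keras_mp = 16.8          # Keras3/TF2, mixed precision, jit (XLA), tf.data, SPE=512
t_torch_mp = 26.9          # PyTorch, mixed precision, jit, NO Scaler, preloaded


def f_p2k(t_keras: float, t_torch: float) -> float:
    """Assumed definition of FP2K: values > 1 mean PyTorch is faster than Keras3/TF2."""
    return t_keras / t_torch


lower = f_p2k(t_keras_fp32, t_torch_fp32_nojit)
upper = f_p2k(t_keras_fp32, t_torch_fp32_jit)
mixed = f_p2k(t_keras_mp, t_torch_mp)
print(f"{lower:.2f} {upper:.2f} {mixed:.2f}")  # 1.17 1.37 0.62
```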

I want to emphasize that these results refer to small NN-models, a batch size of 256 and an Nvidia 4060 TI GPU. Things may look different on TPUs, or as soon as Torch can be used with improved xla-compilation on a CUDA GPU.

There seem to be options for improvements on both sides. I find it a bit worrisome that Keras/TF2’s “model.fit ()”-function does not reach a GPU load above 90% even when we use tf.data datasets with tensors preloaded to the GPU. It may be that one can compensate for this effect by using a tailored and optimized training loop referring to the preloaded data by simple indices. On the other hand the interaction of datasets (and respective loader processes) may simply be less efficient than the respective data handling on the PyTorch side.

In the next post of this series I will have a look at Keras3 with its support for a torch backend.

Stay tuned …