Tuning for Efficient Inferencing with vLLM on MI300X

amd
gpus
benchmarking
Author

Leonard Lin

Published

October 24, 2024

Over the past couple of weeks I’ve been doing testing on an 8 x AMD MI300X node provided by Hot Aisle. I’ll have an article on some of my training experiments with MI300s coming soon, but first, a deep dive into inferencing with vLLM.

There’s been plenty of inferencing testing on the MI300X done over the past few months, so my original goal was just to do some quick revalidation to gauge how rapidly the software stack has been maturing, as ROCm 6.2 is significantly improved, and vLLM v0.6 has had recent performance optimizations as well. I’m able to confirm that there have indeed been some big performance gains, but perhaps more interesting, during my testing I also ended up exploring/characterizing a few of the ways that you can tune vLLM for improved performance.

TL;DR

If you’re just perusing or short on time, you can jump straight to the conclusions for a brief summary. Also, just for fun, here’s an 11-minute NotebookLM Deep Dive Podcast summary (not entirely accurate; just listening to the first few seconds, I noticed they mention/fumble some of the benchmark numbers):

System Info

As mentioned, compute for this testing has been provided by Hot Aisle on their newly installed Dell PowerEdge XE9680 systems with 8 x MI300X GPUs. If you like the results, you can rent the exact same setup I’ve tested on here.

  • Intel® Xeon® Platinum 8470 2G, 52C/104T, 16GT/s, 105M Cache, Turbo, HT (350W)
  • AMD MI300X 8-GPU OAM 192GB 750W GPUs [x8] (1.5TB HBM3 VRAM)
  • 64GB RDIMM, 4800MT/s Dual Rank [x32] (2TB System Memory)
  • 15.36TB Enterprise NVMe Read Intensive AG Drive U.2 Gen4 [x8] (122.88TB Storage)
  • Broadcom 57608 Dual Port 200G Q112 Adapter, PCIe Full Height [x8] (8x 400G (3200Gbps ROCEv2 Ethernet))

8 x MI300X

Building vLLM

The vLLM docs recommend building from source with Docker, but in practice you run into all the same problems (if not more) as a bare-metal source build, so I just built in a clean venv. The docs are missing a couple of key details, so here’s a condensed version of how I built vLLM for my system:

# mamba ftw; requires newer cmake than what Ubuntu 24.04 LTS provides
mamba create -n vllm python=3.11 cmake ninja

# I used nightly (2.6.0) but now that 2.5.0 is stable, you can use that for more reliability
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2 -U

# Latest upstream triton for Triton FA
pip install triton

# I had a permission issue installing if I didn't copy the folder... 
cp -r /opt/rocm/share/amd_smi ./
cd amd_smi
pip install .
cd ..

# vLLM time
git clone https://github.com/vllm-project/vllm
cd vllm
git pull

# Dependencies
pip install numba scipy huggingface-hub -U
pip install "numpy<2" -U
pip install -r requirements-rocm.txt

# Undocumented dependencies
pip install setuptools_scm -U

# Build vLLM (change architecture if not MI300X)
PYTORCH_ROCM_ARCH="gfx942" python setup.py develop

# Verify installation
python -c 'import vllm; print(vllm.__version__)'
> 0.6.4.dev9+g5d264f4a

# Native FA2 - if you don't have ninja installed, this will take forever
cd ..
git clone https://github.com/ROCm/flash-attention.git
cd flash-attention
git pull
git submodule update --init
GPU_ARCHS="gfx942" time python setup.py install
cd ..

# bitsandbytes - not strictly necessary
# https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend
pip install 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl'

At the end, you can run python vllm/collect_env.py and you should see something like:

...
PyTorch version: 2.6.0.dev20241015+rocm6.2
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.2.41133-dd7f95766

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.35

Python version: 3.11.10 | packaged by conda-forge | (main, Sep 30 2024, 18:08:57) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-47-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.2.41133
MIOpen runtime version: 3.2.0
Is XNNPACK available: True

...

Versions of relevant libraries:
[pip3] lion-pytorch==0.2.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-triton-rocm==3.1.0+cf34004b8a
[pip3] pyzmq==26.2.0
[pip3] torch==2.6.0.dev20241015+rocm6.2
[pip3] torchao==0.5.0
[pip3] torchaudio==2.5.0.dev20241015+rocm6.2
[pip3] torchtune==0.3.1
[pip3] torchvision==0.20.0.dev20241015+rocm6.2
[pip3] transformers==4.45.2
[pip3] triton==3.1.0
[conda] lion-pytorch              0.2.2                    pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-triton-rocm       3.1.0+cf34004b8a          pypi_0    pypi
[conda] pyzmq                     26.2.0          py311h7deb3e3_3    conda-forge
[conda] torch                     2.6.0.dev20241015+rocm6.2          pypi_0    pypi
[conda] torchao                   0.5.0                    pypi_0    pypi
[conda] torchaudio                2.5.0.dev20241015+rocm6.2          pypi_0    pypi
[conda] torchtune                 0.3.1                    pypi_0    pypi
[conda] torchvision               0.20.0.dev20241015+rocm6.2          pypi_0    pypi
[conda] transformers              4.45.2                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: 6.2.41134-65d174c3e
Neuron SDK Version: N/A
vLLM Version: 0.6.3.post2.dev90+gb7df53cd
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
...

hipBLASLt Issues with -tp 8

Note: when I originally set this up, I ran into errors that looked like this when loading PyTorch’s bundled hipBLASLt libraries:

rocblaslt error: Could not load /home/hotaisle/miniforge3/envs/shaberi/lib/python3.12/site-packages/torch/lib/hipblaslt/library/TensileLibrary_lazy_gfx942.dat

This only affected -tp 8, not -tp 4 or lower. I’d run into PyTorch hipBLASLt issues before, so my go-to workaround was to skip it with TORCH_BLAS_PREFER_HIPBLASLT=0, but TJian from EmbeddedLLM came through with the actual solution, which was to increase the open-files ulimit, e.g. setting:

ulimit -n 131072
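
If you want that limit to persist rather than setting it in each shell, one option (a sketch assuming your distro uses pam_limits, as stock Ubuntu does) is to raise the nofile limits in /etc/security/limits.conf:

# /etc/security/limits.conf - raise the open-files limit for all users
# (illustrative values; log out and back in for this to take effect)
*    soft    nofile    131072
*    hard    nofile    131072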

EmbeddedLLM (who did the original vLLM ROCm port, I believe) recently also published an article on How to Build vLLM on MI300X from Source that covers some similar ground (and some of the same potential issues) if you’re looking for another take on vLLM setup.

Also, you can take a look at my setup ipynb for full details.

vLLM Tuning

This section will detail a couple of the settings we tested to help characterize/improve vLLM performance, and the previous hipBLASLt issue segues nicely into our first tuning topic: hipBLAS vs hipBLASLt.

hipBLAS vs hipBLASLt

hipBLAS is AMD’s standard BLAS interface for vector/matrix operations, while hipBLASLt is a lighter-weight library with kernels tuned specifically for GEMM operations. Since we can easily toggle between them with TORCH_BLAS_PREFER_HIPBLASLT, we can see if there’s any performance difference in practice. We use vLLM’s benchmark_throughput.py to do our testing.
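
For reference, the toggle is just an environment variable, so a comparison run looks something like this (a minimal sketch; the model path and lengths are illustrative, and exact benchmark flags can vary between vLLM versions):

# hipBLASLt (the PyTorch ROCm default)
TORCH_BLAS_PREFER_HIPBLASLT=1 python benchmarks/benchmark_throughput.py \
  --model meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 \
  --input-len 128 --output-len 128 --num-prompts 1000

# plain hipBLAS for comparison
TORCH_BLAS_PREFER_HIPBLASLT=0 python benchmarks/benchmark_throughput.py \
  --model meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 \
  --input-len 128 --output-len 128 --num-prompts 1000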

One thing to preface with: while hipBLASLt basically always outperformed hipBLAS, the performance differences in my testing were highly variable and depended on both input and output length, so if you’re looking to characterize performance exactly, you’ll probably need to simulate your own workload. I saw anywhere from +0.4% to +11.9% in performance.

Here’s a sweep with Llama-3.1-8B-Instruct with input_len=0:

output_len | hipBLAS tps | hipBLASLt tps | % Difference
128 | 10591.1 | 11845.6 | 11.85
256 | 11049.0 | 11089.0 | 0.36
512 | 11125.3 | 11325.7 | 1.80
1024 | 10903.1 | 11180.7 | 2.55
2048 | 10441.8 | 10733.4 | 2.79
4096 | 8849.4 | 8960.3 | 1.25

And here’s one with Llama-3.1-8B-Instruct where we fix output_len=128 and sweep the input_len:

input_len | hipBLAS tps | hipBLASLt tps | % Difference
128 | 20609.0 | 21345.9 | 3.58
256 | 27755.2 | 28269.3 | 1.85
512 | 37137.5 | 39335.0 | 5.92
1024 | 49055.8 | 53931.8 | 9.94
2048 | 61932.7 | 68066.9 | 9.9
4096 | 72107.9 | 77925.5 | 8.07

It looks like the gains get more solid at larger input_len.

I saw a mention that hipBLASLt could also be more efficient for non-even input/output lengths, but when I tested input_len=131, output_len=131, I got a +1.29% uplift, which is actually lower than the 128 in/out gains.

Still, free performance is free performance. It seems safe to say that you should always prefer hipBLASLt (which I believe has been the default since PyTorch 2.4).

As usual, for full details, see my Jupyter notebook.

Flash Attention

ROCm now has two Flash Attention implementations: a Triton version and a Composable Kernel (CK) version. vLLM uses the Triton implementation by default (it is more broadly compatible, working with RDNA3 GPUs for example), but it is missing some features, most notably Sliding Window Attention support (so it won’t work with Mistral and other models that use SWA). The CK version is supposedly upstreamed, but I couldn’t get the upstream FA2 repo to build, and per my setup notes above, I used the ROCm/flash-attention fork. There is a VLLM_USE_TRITON_FLASH_ATTN env variable to control whether or not the Triton FA is used.
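
Concretely, the comparison below just flips that environment variable (a hedged sketch; it assumes the ROCm/flash-attention build from the setup section, and the model path and lengths are illustrative):

# Triton FA (the vLLM default on ROCm)
VLLM_USE_TRITON_FLASH_ATTN=1 python benchmarks/benchmark_throughput.py \
  --model meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 \
  --input-len 128 --output-len 128

# CK FA (falls back to the ROCm/flash-attention build)
VLLM_USE_TRITON_FLASH_ATTN=0 python benchmarks/benchmark_throughput.py \
  --model meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 \
  --input-len 128 --output-len 128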

Our testing with Flash Attention is similar to our prior hipBLASLt testing. Here’s what a sweep of Llama-3.1-8B-Instruct with input_len=0 looks like:

output_len | triton tps | ck tps | % Difference
128 | 11111.6 | 11985.9 | 7.87
256 | 11191.0 | 12422.2 | 11.00
512 | 11670.6 | 11548.6 | -1.05
1024 | 10888.4 | 11077.8 | 1.74
2048 | 9844.9 | 9828.1 | -0.17
4096 | 7995.5 | 8025.8 | 0.38

While the CK version is faster overall, it sometimes loses to the Triton version.

Let’s see what it looks like with a fixed output_len=128 and variable input_len:

input_len | triton tps | ck tps | % Difference
128 | 17381.8 | 21647.3 | 24.54
256 | 24437.4 | 29065.8 | 18.94
512 | 39432.0 | 40659.1 | 3.11
1024 | 48892.9 | 53986.2 | 10.42
2048 | 62844.8 | 68172.6 | 8.48
4096 | 71548.2 | 78139.2 | 9.21

We sometimes see huge (almost 25%) performance gains using the CK FA vs Triton as input_len varies.

Finally, let’s test some oddball input_len/output_len combinations just to see. In all of these, the CK version wins out:

input_len | output_len | triton tps | ck tps | % Difference
131 | 131 | 17602.3 | 21533.2 | 22.33
2000 | 2000 | 16936.5 | 17104.8 | 0.99
2048 | 2048 | 16811.4 | 17143.6 | 1.98

Basically, if you are not using the CK FA implementation (vLLM defaults to the Triton FA), then you are probably missing out on a fair amount of performance (not to mention, not using FA at all on any SWA models). (notebook)

Tensor Parallelism vs Pipeline Parallelism

Note: for all our tests I’ve been using -tp 8 (tensor parallelism), which has been vLLM’s default way of serving across multiple GPUs, but it’s probably also worth noting that vLLM now has beta support for pipeline parallelism as well for certain models.
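
To make the configurations concrete, here’s a rough sketch of what the options look like when serving (HIP_VISIBLE_DEVICES for GPU selection is an assumption about your environment, and the model path and ports are illustrative):

# 8-way tensor parallelism (what I used for the tests above)
vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --tensor-parallel-size 8

# tp=2 per pipeline stage, pp=2 stages (4 GPUs total)
vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --tensor-parallel-size 2 --pipeline-parallel-size 2

# two independent tp=1 replicas (the "2 x tp=1" results below), load balanced externally
HIP_VISIBLE_DEVICES=0 vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --port 8000 &
HIP_VISIBLE_DEVICES=1 vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --port 8001 &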

In general, it seems like you will probably get superior token throughput by using the smallest tp required to fit your model in memory and then using pipeline parallelism or another distributed backend to serve across your GPUs. For lower batch sizes/concurrency, you might be better off with a higher tp. I ran this test with Mixtral-8x7B-Instruct-v0.1 as part of replicating some prior benchmarks, so this is a preliminary conclusion that needs more exploration:

bsz | tp=2 tps | 2 x tp=1 tps | % Difference
1 | 125.5 | 119.4 | -4.85
2 | 173.6 | 238.8 | 37.52
4 | 303.7 | 315.9 | 3.99
8 | 785.2 | 530.9 | -32.39
16 | 1212.3 | 1033.3 | -14.76
32 | 2628.7 | 1634.5 | -37.82
64 | 4246.4 | 3393.6 | -20.08
128 | 6444.2 | 6376.9 | -1.04
256 | 6936.8 | 9369.2 | 35.07
512 | 7502 | 11578.2 | 54.34
1024 | 7942.3 | 12604.5 | 58.70

More details on these results can be found in this very messy notebook.

2024-10-30 UPDATE: Scheduler Steps

I’ve been making some updates and errata at the bottom of the page, but this optimization is big enough that I feel like I have to add it to the main section. I only just saw this, but apparently last week EmbeddedLLM also posted an excellent writeup, Serving LLMs on AMD MI300X: Best Practices, on the vLLM Blog. It’s well worth a read. I went through and ran all the suggested optimizations, and the one that stood out was modifying the multi-step scheduler. This was one of the vLLM 0.6 performance improvements, but what I didn’t realize was that by default, --num-scheduler-steps is set to 1 (effectively off) and that a recommended value is 10-15.
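
The change itself is just a single engine argument; here’s a hedged sketch of what that looks like (model path and lengths are illustrative, and the flag only exists in recent vLLM builds):

# offline throughput benchmark with multi-step scheduling
python benchmarks/benchmark_throughput.py --model meta-llama/Llama-3.1-8B-Instruct \
  --tensor-parallel-size 8 --num-scheduler-steps 15 --input-len 128 --output-len 128

# the same flag applies when serving
vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 --num-scheduler-steps 15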

In my testing, I set --num-scheduler-steps 15 and the throughput difference was, quite frankly, insane:

input_len | output_len | ss=1 tps | ss=15 tps | % Difference
0 | 128 | 11728.8 | 21396.5 | 82.43
0 | 256 | 11217.2 | 21638.6 | 92.91
0 | 512 | 11497.5 | 21993.7 | 91.29
0 | 1024 | 11203.4 | 21907.4 | 95.54
0 | 2048 | 10702.3 | 19976.1 | 86.65
0 | 4096 | 8965.8 | 14485.7 | 61.57

input_len | output_len | ss=1 tps | ss=15 tps | % Difference
128 | 128 | 21318.3 | 35745.2 | 67.67
256 | 128 | 28244.2 | 45797.9 | 62.15
512 | 128 | 40052.5 | 59037.9 | 47.40
1024 | 128 | 51987.4 | 72281.8 | 39.04
2048 | 128 | 66720.0 | 83190.6 | 24.69
4096 | 128 | 76695.1 | 89324.9 | 16.47

Basically, increasing your scheduler steps can almost double your token generation throughput. So, um, yeah, you should probably do that.

Tuning Conclusions

While my testing wasn’t comprehensive, I think the results are strong enough to say that you will want to use hipBLASLt (the default) and the CK FA2 implementation (not the default! set VLLM_USE_TRITON_FLASH_ATTN=0) when running vLLM for improved performance. You should also definitely set --num-scheduler-steps to something higher than the default.
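
Pulling those together, a serving command with the above recommendations applied looks roughly like this (a sketch; the model path and step count are illustrative):

# hipBLASLt stays at its default (on); CK FA2 and multi-step scheduling are opted into
VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --tensor-parallel-size 8 --num-scheduler-steps 15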

Benchmarks

OK, with that out of the way, let’s get to what I was planning on doing, which was to compare how the latest ROCm 6.2 + vLLM 0.6.x improves on older benchmarks. We’ll work through this chronologically.

Mixtral 8x7B

Back in June 2024, Tensorwave published AMD’s MI300X Outperforms NVIDIA’s H100 for LLM Inference which tested MI300X performance with MK1 Flywheel.

bsz | MK1 2GPU tps | vLLM 2GPU tps | % Difference
1 | 142 | 119.4 | -15.92
2 | 280 | 238.8 | -14.71
4 | 466 | 315.9 | -32.22
8 | 804 | 530.9 | -33.96
16 | 1452 | 1033.3 | -28.83
32 | 2662 | 1634.5 | -38.60
64 | 4442 | 3393.6 | -23.60
128 | 6348 | 6376.9 | 0.46
256 | 6292 | 9369.2 | 48.91
512 | 6292 | 11578.2 | 84.01
1024 | 6288 | 12604.5 | 100.45

It’s interesting to note that while vLLM 0.6.x still seems to fall a bit behind at lower batch sizes, at bsz=128 things turn around, and whereas MK1 Flywheel (from June 2024) seems to top out, vLLM now keeps scaling, ending up at +100% TPS at bsz=1024.

Runpod also published their own Mixtral benchmarks in July 2024, AMD MI300X vs. Nvidia H100 SXM: Performance Comparison on Mixtral 8x7B Inference, with a very similar methodology (extrapolating 2 GPUs at TP1). They use vLLM 0.4.3 and ROCm 6.1.2 (presumably PyTorch 2.3? not mentioned) so here we can get a relatively clean comparison of how the latest versions of the same software improve performance:

bsz | Runpod 2GPU tps | lhl 2GPU tps | % Difference
1 | 122.2 | 119.4 | -2.33
2 | 244.5 | 238.8 | -2.33
4 | 377 | 315.9 | -16.21
8 | 550.6 | 530.9 | -3.58
16 | 1078 | 1033.3 | -4.15
32 | 1756.6 | 1634.5 | -6.95
64 | 3236.3 | 3393.6 | 4.86
128 | 5043.5 | 6376.9 | 26.44
256 | 7208 | 9369.2 | 29.98
512 | 7989.1 | 11578.2 | 44.92
1024 | 8801.7 | 12604.5 | 43.21

While most of the numbers up to bsz=64 seem within system variance (except for bsz=4, hmm), we can see that above that, we seem to be getting some pretty huge gains. At bsz=1024, each GPU is pushing almost an extra 2000 tok/s.

Llama 3.1 405B

The last comparison I want to make is actually a bit different, and may be interesting in a different way. dstack published Benchmarking Llama 3.1 405B on 8x AMD MI300X GPUs just a week or so ago in October 2024, running on an almost identical Hot Aisle machine (SVR13 vs SVR09, so sitting within a rack of each other).

They published their system details, and they used what appears to be a HEAD build of vLLM very close to mine, albeit against ROCm 6.1 and the then-stable version (2.4.1) of PyTorch.

dstack’s Surprising vLLM vs TGI results

They found that TGI significantly outperformed vLLM 0.6.x, which was quite surprising. The last big comparison I saw was BentoML’s Benchmarking LLM Inference Backends: vLLM, LMDeploy, MLC-LLM, TensorRT-LLM, and TGI back in June, and token generation seemed comparable across most of the backends (those tests were on A100s and with Llama 3 8B and 70B Q4 models, so admittedly not quite 1:1).

Still, since I had an almost identical system at my fingertips, I did a quick initial test when I first saw those numbers and didn’t seem to be seeing the same huge differences.

Since dstack also published their full testing repo as well, I felt it was worth spending a bit more effort in trying to replicate some of their testing.

Instead of benchmark_throughput.py, their testing used vLLM’s benchmark_serving.py. Llama 3.1 405B is big enough that loading/benching it isn’t particularly pleasant, so I just re-ran their first 80-token test to try to match their topline results.
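
For context, benchmark_serving.py measures a running server rather than the offline engine, so the flow is roughly the following (a hedged sketch, not dstack’s exact configuration; see their repo for the real parameters, and treat the model path, token lengths, and prompt count here as illustrative):

# 1) bring up an OpenAI-compatible server
vllm serve meta-llama/Llama-3.1-405B-Instruct --tensor-parallel-size 8

# 2) drive it from another shell with synthetic prompts
python benchmarks/benchmark_serving.py --backend vllm \
  --model meta-llama/Llama-3.1-405B-Instruct \
  --dataset-name random --random-input-len 80 --random-output-len 80 \
  --num-prompts 256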

I found that, running the same tests, my ROCm 6.2 + PyTorch 2.6 + CK FA setup got significantly better out-of-the-box (OOTB) results:

bsz | dstack vLLM tps | lhl vLLM tps | % Difference
1 | 36.4 | 32.3 | -11.37
2 | 61.6 | 62.7 | 1.72
4 | 112.3 | 115.9 | 3.21
8 | 198.9 | 200.3 | 0.69
16 | 337.3 | 221.7 | -34.27
32 | 524.1 | 406.1 | -22.50
64 | 460.6 | 681.0 | 47.85
128 | 645.9 | 980.6 | 51.83
256 | 814.3 | 1220.5 | 49.88
512 | 955.8 | 1412.3 | 47.77
1024 | 1038.3 | 1549.8 | 49.25
2048 | 1049.0 | 1609.3 | 53.42

From bsz=64 on, throughput is about +50% faster, which maybe mirrors the Runpod performance delta. My vLLM results were still about 30% lower than the TGI numbers, however.

TTFT was a similar story, with my tests again significantly improved (but still worse than the TGI results):

bsz | dstack vLLM Mean TTFT (ms) | lhl vLLM Mean TTFT (ms) | % Difference
1 | 314.58 | 239.86 | -23.75
2 | 372.80 | 258.13 | -30.76
4 | 397.54 | 319.37 | -19.66
8 | 646.20 | 484.87 | -24.97
16 | 1034.69 | 675.79 | -34.69
32 | 1546.00 | 1132.04 | -26.78
64 | 2698.98 | 2022.17 | -25.08
128 | 5161.26 | 3773.87 | -26.88
256 | 11787.30 | 8584.42 | -27.17
512 | 27060.40 | 19352.70 | -28.48
1024 | 58452.87 | 42155.73 | -27.88
2048 | 128150.00 | 88453.70 | -30.98

Now, I’m not really a vLLM tuning expert, but I wouldn’t have expected such a big difference (or for anyone to still be using vLLM if that were the case). So, once again I hit up TJian, and he mentioned that the dstack TGI config was set to --max-concurrent-requests 8192, which was significantly higher than vLLM’s equivalent.

I couldn’t find a way to adjust vLLM’s concurrent requests exactly, but what vLLM does seem to have is a max_num_batched_tokens parameter (found via this issue), which appears to feed into a formula based on GPU blocks. vLLM has a very short Performance and Tuning page in its docs, which at the bottom includes this little bit on vLLM’s batching:

You can tune the performance by changing max_num_batched_tokens. By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B). Smaller max_num_batched_tokens achieves better ITL because there are fewer prefills interrupting decodes. Higher max_num_batched_tokens achieves better TTFT as you can put more prefill to the batch.

If max_num_batched_tokens is the same as max_model_len, that’s almost the equivalent to the default scheduling policy (except that it still prioritizes decodes).

Note that the default value (512) of max_num_batched_tokens is optimized for ITL, and it may have lower throughput than the default scheduler.

We recommend you set max_num_batched_tokens > 2048 for throughput.

So, armed with this, I ran the tests with max_num_batched_tokens=1024 and max_num_batched_tokens=2048 (beyond that it OOMs).
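
Like the scheduler steps, this is just an engine argument, so the tweak amounts to something like the following (model path illustrative):

vllm serve meta-llama/Llama-3.1-405B-Instruct --tensor-parallel-size 8 \
  --max-num-batched-tokens 2048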

For max_num_batched_tokens=2048, we now get to a solid 2X of dstack’s original results:

bsz | dstack vLLM tps | lhl vLLM tps | % Difference
1 | 36.40 | 32.11 | -11.79
2 | 61.65 | 62.17 | 0.84
4 | 112.27 | 117.55 | 4.70
8 | 198.89 | 203.16 | 2.15
16 | 337.26 | 229.27 | -32.02
32 | 524.06 | 439.97 | -16.05
64 | 460.56 | 753.95 | 63.70
128 | 645.88 | 1152.32 | 78.41
256 | 814.31 | 1531.02 | 88.01
512 | 955.75 | 1865.03 | 95.14
1024 | 1038.34 | 2087.99 | 101.09

And while TGI still seems to have a lead at some batch sizes, at higher batch sizes vLLM finally seems to be about neck and neck:

bsz | lhl vLLM tps | dstack TGI tps | % Difference
1 | 32.11 | 53.52 | 66.68
2 | 62.17 | 60.80 | -2.20
4 | 117.55 | 143.13 | 21.76
8 | 203.16 | 266.97 | 31.41
16 | 229.27 | 422.14 | 84.12
32 | 439.97 | 528.31 | 20.08
64 | 753.95 | 884.41 | 17.30
128 | 1152.32 | 1226.40 | 6.43
256 | 1531.02 | 1612.85 | 5.34
512 | 1865.03 | 1850.80 | -0.76
1024 | 2087.99 | 1991.88 | -4.60

Now, this is only some preliminary testing with one single input/output combination, so there’s probably a lot left to explore. I haven’t looked very much at TGI’s AMD optimizations (or at TGI at all, actually).

I also left a fair bit of the raw numbers out for brevity, but all the testing data is available in my MI300-testing repo.

Conclusions

I think the dstack replication/tuning is actually a great way to wrap up, as it shows that there can be huge performance disparities when running on exactly the same hardware with very slight configuration/software version changes.

It also points towards exploring beyond vLLM for inference. Besides TGI, SGLang also has AMD support and looks promising, and there are probably other inference engines suitable for MI300X as well.

There’s also plenty of performance still left on the table for vLLM inferencing as well.

But, in terms of actual conclusions from the testing:

  • Running the latest ROCm, PyTorch, CK FA, and vLLM can give you substantial (up to 50%) throughput improvements on MI300 now vs a few months ago
  • hipBLASLt consistently outperforms hipBLAS; you should always try to use it
  • The Composable Kernel Flash Attention can offer even bigger performance gains vs the Triton version and should probably be preferred. You will need to build it separately and also set VLLM_USE_TRITON_FLASH_ATTN=0
  • Proper tuning of max_num_batched_tokens is key for vLLM performance and you might have up to a 100% throughput difference depending on your batching

Overall, MI300X is already a strong platform for inferencing, but it’s clear there’s a lot yet to be properly explored (or at least, publicly published). Hopefully, this testing helps to push those boundaries a bit.

Updates