FlashOffload: 7x Cheaper Prefills with Offloading
Improving SGLang’s Offloading Engine to get 7x Cheaper Prefills on DeepSeek V4 Flash
Offloading in LLM inference means keeping part of the model in slower CPU memory so that you can run models larger than your GPU HBM can normally hold. Generally, it is a bad idea.
GPU HBM is extremely fast, CPU memory is much slower, and the interconnect between them is slower still. If every layer has to pause while weights are copied from the host, the benefit of offloading (being able to run models larger than your HBM) is usually not worth the slowdown.
But the hardware story is changing. Grace Hopper systems connect CPU and GPU memory with much higher bandwidth. On GH200, host-to-device bandwidth can be as high as 450GB/s, fast enough that offloading is no longer automatically absurd. It is still not free, but it can become useful when the transfer can be hidden behind GPU compute.
This work was motivated by our recent access to Isambard, the UK supercomputer with Grace Hopper nodes. GH200 is one of the first platforms where CPU-GPU bandwidth is high enough that inference-engine offloading feels worth revisiting seriously, rather than treating it as a last-resort fallback when the model does not fit.
Why Offload at All?
The main reason is capacity. Large models do not always fit neatly into the GPU allocations you can actually get. This becomes particularly important when using GPUs in spot markets. Spot markets are tight, especially for the kind of large, scale-out GPU allocations that big models usually want. If a model requires a full node or a multi-node deployment, it becomes much harder to opportunistically use transient compute.
You also need to consider the type of workload. LLM inference is split into two phases with very different workload profiles which we will consider separately:
- Decode: one or a few new tokens at a time. This is often bandwidth-bound.
- Prefill: many prompt tokens at once. This is much more FLOP-heavy.
The concrete target in this post is DeepSeek V4 Flash, running in 4-bit on a single GH200 with 96GB of GPU memory and 120GB of CPU memory. The model requires about 160GB in 4-bit, so it cannot fit entirely in HBM and it also cannot fit entirely in the CPU memory pool. It has to be split across both, which makes it a useful test case for whether offloading can be made efficient rather than merely possible.
The Arithmetic of Hiding Weight Transfers
Suppose we move $W_{\text{off}}$ GB of weights out of GPU HBM and keep them in CPU memory. To use those weights in a forward pass, we need to copy them back to the GPU:
$$t_{\text{copy}} = \frac{W_{\text{off}}}{B_{\text{H2D}}}$$
Here $t_{\text{copy}}$ is the host-to-device copy time, $W_{\text{off}}$ is the number of gigabytes that must be copied from CPU memory into GPU memory for the layer, and $B_{\text{H2D}}$ is the sustained host-to-device bandwidth between CPU memory and GPU HBM.
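As a quick check, the copy time is simply the offloaded volume divided by the host-to-device bandwidth. The sketch below evaluates it at the 450GB/s peak quoted earlier for a few illustrative sizes; sustained bandwidth will usually be somewhat lower, so treat these as optimistic lower bounds on copy time.

```python
def copy_time_s(w_off_gb: float, b_h2d_gb_s: float) -> float:
    """Host-to-device copy time in seconds: offloaded GB / bandwidth in GB/s."""
    if b_h2d_gb_s <= 0:
        raise ValueError("bandwidth must be positive")
    return w_off_gb / b_h2d_gb_s


if __name__ == "__main__":
    # 450 GB/s is the peak GH200 figure quoted above; sustained rates are lower.
    for w_off_gb in (1.0, 6.0, 45.0):
        print(f"{w_off_gb:5.1f} GB -> {copy_time_s(w_off_gb, 450.0) * 1e3:6.1f} ms")
```

At that peak, a 6GB offload costs about 13 ms per copy and 45GB costs about 100 ms, which is why the size of the offloaded slice dominates everything that follows.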
The only way this does not slow us down is if we can overlap the copy with some other useful GPU work. If the GPU is busy for long enough, the copy becomes hidden. If the GPU finishes first, the forward pass stalls.
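For decode, the amount of compute per layer is small, so a decode step takes roughly as long as streaming the resident weights from HBM. The copy is hidden only if it finishes within that time, which amounts to comparing the host-to-device bandwidth with the GPU's own memory bandwidth:
$$\frac{W_{\text{off}}}{B_{\text{H2D}}} \le \frac{W_{\text{GPU}}}{B_{\text{HBM}}} \quad\Longleftrightarrow\quad \frac{W_{\text{off}}}{W_{\text{GPU}}} \le \frac{B_{\text{H2D}}}{B_{\text{HBM}}}$$
Here $W_{\text{off}}$ and $B_{\text{H2D}}$ are the offloaded volume and host-to-device bandwidth as before, $W_{\text{GPU}}$ is the amount of resident GPU data streamed from HBM during the compute window, and $B_{\text{HBM}}$ is the GPU's effective HBM memory bandwidth. In the right-hand form, the left-hand side is the offloaded fraction we can hope to hide, and the right-hand side is the bandwidth ratio that bounds it.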
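On the GH200 systems available on Isambard, this ratio is roughly 0.1: about 450GB/s of host-to-device bandwidth against roughly 4TB/s of HBM3 bandwidth. That means if the GPU is streaming $W_{\text{GPU}}$ GB from HBM during a decode step, it can only hide about $0.1\,W_{\text{GPU}}$ GB of CPU-to-GPU traffic.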
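The decode bound can be turned into a small budget calculator. This sketch assumes about 450GB/s of host-to-device bandwidth and roughly 4TB/s of HBM bandwidth (the published peak for the 96GB HBM3 GH200); `hideable_offload_gb` is a hypothetical helper, not part of SGLang.

```python
def hideable_offload_gb(w_gpu_gb: float, b_h2d_gb_s: float, b_hbm_gb_s: float) -> float:
    """Largest offload (GB) whose copy fits inside a bandwidth-bound decode step.

    Streaming w_gpu_gb from HBM takes w_gpu_gb / b_hbm_gb_s seconds; copying
    W_off over the host link takes W_off / b_h2d_gb_s. Equating the two gives
    W_off = w_gpu_gb * b_h2d_gb_s / b_hbm_gb_s.
    """
    return w_gpu_gb * b_h2d_gb_s / b_hbm_gb_s


if __name__ == "__main__":
    # Assumed peaks: ~450 GB/s host-to-device and ~4000 GB/s HBM3 (96GB GH200).
    B_H2D, B_HBM = 450.0, 4000.0
    print(f"bandwidth ratio ~ {B_H2D / B_HBM:.2f}")
    for w_gpu_gb in (20.0, 50.0, 80.0):
        hide = hideable_offload_gb(w_gpu_gb, B_H2D, B_HBM)
        print(f"stream {w_gpu_gb:4.0f} GB -> hide ~{hide:.1f} GB")
```

Under these assumptions, a decode step that streams around 50GB from HBM can hide only about 5.6GB of offloaded weights, the same order as the 6GB figure discussed next.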
On the 96GB GPU in a GH200 node, that gives you only a small offload budget during decode. If you hide 6GB of offloaded weights, things can look fine. If you offload something more like 45GB and need to pull it back through the CPU link, the forward pass can slow down dramatically.
Prefill is different. With a long enough prompt or a large enough batch, the GPU does much more compute per layer. Using the formula for the number of FLOPs in an MoE layer, the rough overlap condition becomes:
$$\frac{W_{\text{off}}}{B_{\text{H2D}}} \le \frac{2 \cdot 2 \cdot L \, k \, T \, d \, d_{\text{ff}}}{F}$$
In this rough estimate, $W_{\text{off}}$ and $B_{\text{H2D}}$ are the offloaded volume and host-to-device bandwidth as before, $L$ is the number of MoE layers whose compute contributes to the overlap window, $k$ is the number of experts selected per token, $T$ is the number of prefill tokens, $d$ is the model hidden size, and $d_{\text{ff}}$ is the expert intermediate size. The two factors of 2 account for multiply-add FLOP counting and the two large linear projections in a simple expert MLP. $F$ is the sustained GPU compute rate we expect to achieve on that workload.
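The prefill condition can be rearranged into a minimum token count. The sketch below does that; the model dimensions in the usage example (`d_model=4096`, `d_ff=2048`, `top_k=6`) and the 500 TFLOP/s sustained rate are hypothetical placeholders, not DeepSeek V4 Flash's real configuration. The estimate also ignores attention and any shared-expert FLOPs, which would only widen the window.

```python
import math


def moe_compute_window_s(num_layers, top_k, tokens, d_model, d_ff, flops_per_s):
    """Rough routed-expert compute time, 2 * 2 * L * k * T * d * d_ff / F.

    One factor of 2 counts a multiply-add as two FLOPs; the other covers the two
    large projections of a simple expert MLP. Attention and shared experts are
    ignored, which tends to make the window an underestimate.
    """
    return 2 * 2 * num_layers * top_k * tokens * d_model * d_ff / flops_per_s


def min_prefill_tokens(w_off_gb, b_h2d_gb_s, num_layers, top_k, d_model, d_ff, flops_per_s):
    """Smallest token count whose estimated compute window covers the copy."""
    t_copy = w_off_gb / b_h2d_gb_s
    per_token = moe_compute_window_s(num_layers, top_k, 1, d_model, d_ff, flops_per_s)
    return math.ceil(t_copy / per_token)


if __name__ == "__main__":
    # Hypothetical placeholders, NOT DeepSeek V4 Flash's real configuration.
    dims = dict(num_layers=1, top_k=6, d_model=4096, d_ff=2048, flops_per_s=500e12)
    for w_off_gb in (0.5, 1.5, 3.0):
        t_min = min_prefill_tokens(w_off_gb, 450.0, **dims)
        print(f"{w_off_gb} GB per layer -> T >= {t_min}")
```

The required token count grows linearly with the offloaded volume per layer, which is the "direction of travel" argument in numerical form.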
The exact constants are less important than the direction of travel. As the number of prefill tokens grows, the right-hand side grows. Eventually the computation is long enough to hide meaningful host-to-device traffic.
That makes offloading interesting for FLOP-dominated inference workloads:
- prefill nodes in a disaggregated serving system,
- diffusion-style transformer workloads,
- encoder-only or encoder-heavy models.
The most practically useful case for LLM serving is disaggregated prefill. If you can use transient GPUs as prefill workers, and those workers do not need to hold the full model in HBM at all times, you get a much more flexible serving primitive.
What SGLang Does Today
SGLang already has a CPU offload path, exposed through:
--cpu-offload-gb
This is a useful baseline, but the default implementation is not designed around overlap.
When a layer needs a weight that currently lives in CPU memory, the forward pass pauses, copies the weight to the GPU, and then continues. The transfer is directly on the critical path. In a profiler trace, you see compute blocks separated by host-to-device copies.
Figure 1: The default CPU offload path pays the copy cost directly. The schematic shows the expected bubble in the forward pass, and the SGLang trace below it shows the same pattern in practice: host-to-device copies appear in the hot path and block later kernel launches.
This is the version of offloading that usually disappoints. It may let the model run, but it is not a good serving primitive if every offloaded layer inserts a visible bubble.
FlashOffload: Double-Buffered Weights
FlashOffload changes the scheduling logic to promote as much overlap between compute and memory movement as possible.
Instead of copying a layer’s offloaded weights at the moment the layer needs them, we prefetch the next layer’s weights while the current layer is computing. The implementation uses two GPU buffers:
- buffer 0 holds the offloaded weights needed by one layer,
- buffer 1 is filled with the offloaded weights needed by the next layer.
While layer $i$ runs using buffer $i \bmod 2$, we copy layer $i+1$ into buffer $(i+1) \bmod 2$. When the model advances to the next layer, the roles swap.
Figure 2: FlashOffload keeps two full expert buffers on the GPU. During one layer, the compute stream consumes the current buffer while the copy stream prepares the next buffer; on the following layer, the buffers swap roles.
The important property is that the copy stream and the compute stream are independent. The forward pass only needs to wait if the next buffer is not ready by the time compute reaches it.
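To make the waiting condition concrete, here is a small discrete-event model of the two-buffer schedule. It is a hypothetical sketch, not SGLang code: `simulate_double_buffer` assumes a single serial copy stream, that layer i reuses the buffer freed by layer i - 2, and fixed per-layer copy and compute times.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Timeline:
    compute_start: List[float]
    compute_end: List[float]
    copy_end: List[float]
    total_stall: float  # time the compute stream spent waiting for buffers


def simulate_double_buffer(copy_s: List[float], compute_s: List[float]) -> Timeline:
    """Toy model of a two-buffer prefetch schedule (hypothetical, not SGLang code).

    Layer i uses buffer i % 2. Its copy runs on one serial copy stream and may
    start only after layer i - 2, the previous user of that buffer, has
    finished computing. Layer i starts computing once layer i - 1 is done and
    its own buffer is ready; any gap between the two counts as a stall.
    """
    n = len(copy_s)
    if n == 0 or len(compute_s) != n:
        raise ValueError("need matching, non-empty per-layer timings")
    compute_start = [0.0] * n
    compute_end = [0.0] * n
    copy_end = [0.0] * n
    stall = 0.0
    for i in range(n):
        copy_stream_free = copy_end[i - 1] if i >= 1 else 0.0
        buffer_free = compute_end[i - 2] if i >= 2 else 0.0
        copy_end[i] = max(copy_stream_free, buffer_free) + copy_s[i]
        prev_done = compute_end[i - 1] if i >= 1 else 0.0
        compute_start[i] = max(prev_done, copy_end[i])
        stall += compute_start[i] - prev_done
        compute_end[i] = compute_start[i] + compute_s[i]
    return Timeline(compute_start, compute_end, copy_end, stall)


if __name__ == "__main__":
    hidden = simulate_double_buffer([1.0] * 4, [2.0] * 4)
    exposed = simulate_double_buffer([3.0] * 4, [1.0] * 4)
    print("copy < compute:", hidden.total_stall, hidden.compute_end[-1])
    print("copy > compute:", exposed.total_stall, exposed.compute_end[-1])
```

When every copy is shorter than the compute it overlaps, only the first (prologue) copy is exposed. Once copies exceed compute, every subsequent layer stalls by roughly the difference between the two.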
Here are the key pieces from the SGLang implementation.
First, during loading, each packed expert tensor is split into a GPU-resident prefix and a pinned CPU suffix. The pinned CPU tensors are the part we intend to stream into the active layer buffer later:
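# Experts [0, num_gpu_experts) stay resident in HBM; the rest move to host memory.
num_gpu_experts = num_local_experts - self.offload_experts
gpu_tensors[name] = tensor[:num_gpu_experts].clone(
    memory_format=torch.contiguous_format
)
# Pinned (page-locked) memory is what lets later non_blocking copies run asynchronously.
cpu_tensors[name] = torch.empty(
    (self.offload_experts, *tensor.shape[1:]),
    dtype=tensor.dtype,
    device="cpu",
    pin_memory=True,
)
cpu_tensors[name].copy_(tensor[num_gpu_experts:])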
Second, the prefetch runs on a dedicated CUDA stream. It copies both the GPU-resident prefix and CPU-resident suffix into the selected full layer buffer, then records an event to signal that the buffer is ready:
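# Capture the compute stream first: inside the context below,
# torch.cuda.current_stream() returns self.copy_stream itself.
compute_stream = torch.cuda.current_stream(self.device)
with torch.cuda.stream(self.copy_stream):
    # Order the copy after the work that must finish before this buffer is
    # overwritten (an explicit event if given, else all queued compute).
    if wait_event is not None:
        self.copy_stream.wait_event(wait_event)
    else:
        self.copy_stream.wait_stream(compute_stream)
    for name, dst in buffer.tensors.items():
        # GPU-resident prefix: device-to-device copy.
        dst[:num_gpu].copy_(store.gpu_tensors[name], non_blocking=True)
        # Offloaded suffix: asynchronous host-to-device copy from pinned memory.
        dst[num_gpu:num_gpu + num_cpu].copy_(
            store.cpu_tensors[name], non_blocking=True
        )
    # Signal that the buffer is ready; compute waits on this event before using it.
    buffer.ready_event.record(self.copy_stream)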
Finally, the DeepSeek V4 forward loop binds the prepared buffer immediately before the MoE layer runs, unbinds it afterwards, and then launches the prefetch for the layer that will reuse the same buffer two MoE layers later:
expert_buffer_manager.begin_forward()
for i in range(self.start_layer, self.end_layer):
    layer = self.layers[i]
    # bind a full expert buffer to the model so it uses it in the forward pass
    # (waits for the buffer's ready event first)
    expert_buffer_manager.wait_and_bind(i, layer.mlp.experts)
    try:
        hidden_states, *_ = layer(...)
    finally:
        # remove a buffer from the model in preparation for the next layer
        expert_buffer_manager.unbind(layer.mlp.experts)
    # release this layer's buffer and prefetch layer i + 2 into it
    expert_buffer_manager.layer_done_and_prefetch(i)
In the good case, the profiler trace changes shape: copies are still happening, but they are fully hidden behind GPU compute.

Figure 3: With double buffering, host-to-device copies can be overlapped with the current layer’s GPU computation. The schematic shows the intended schedule, and the SGLang prefill trace below it shows the same behavior in the implementation: memory movement is almost entirely hidden under compute.
If the prefill workload is large enough, and the offloaded slice is small enough, the transfer can disappear from the end-to-end runtime.
On GH200, this is enough to make a large prefill step meaningfully faster. With SGLang's default offloading, a 16,384-token prefill took 2.1s, or about 7,800 tok/s. With FlashOffload, the same step took 1.47s, or about 11,100 tok/s. That is a 42% throughput improvement from changing the offload schedule alone. If you use transient 1xGH200 spot instances at $1/hr (a price I found on Prime Intellect's spot market), this equates to an input token price of $0.025/MTok, 7x cheaper than the DeepSeek V4 Flash input token price that DeepSeek themselves offer.
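The price figure follows directly from throughput and the hourly rate. This minimal sketch reproduces it from the measured step times above; it assumes the GPU is kept fully busy with prefill work, so any idle time raises the effective price proportionally, and the $1/hr rate can be swapped for any other.

```python
def tokens_per_s(tokens: int, seconds: float) -> float:
    """Prefill throughput for one step."""
    return tokens / seconds


def usd_per_mtok(usd_per_hour: float, tok_per_s: float) -> float:
    """Cost per million tokens for a GPU rented at usd_per_hour, kept fully busy."""
    return usd_per_hour / (tok_per_s * 3600.0) * 1e6


if __name__ == "__main__":
    for label, seconds in (("default offload", 2.1), ("FlashOffload", 1.47)):
        tps = tokens_per_s(16_384, seconds)
        print(f"{label:16s} {tps:8.0f} tok/s  ${usd_per_mtok(1.0, tps):.4f}/MTok")
```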
What Happens When the Copy Is Too Long?
The overlap condition is not magic. If the next layer’s weights take longer to copy than the current layer takes to compute, the forward pass stalls.
That is useful to see explicitly.

Figure 4: When the host-to-device copy is longer than the compute window, only part of the transfer is hidden and the remaining tail becomes visible latency. The schematic shows the failure mode, and the SGLang trace shows it during batch-size-14 DeepSeek V4 Flash decoding: the copy extends far past the compute window and the GPU stalls before it can launch the next kernels.
This gives us a clean tuning rule. For a given hardware platform and model, the offloaded weight volume per layer has to fit inside the compute window. Longer prefills, larger batches, and more FLOP-heavy layers all increase the window. More offloaded weights shrink the margin.
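The tuning rule can be expressed as a budget on how many experts per layer to offload. This is a hypothetical helper: it assumes you have measured the per-layer compute window from a profile, that every expert in a layer has the same size, and that only the pinned CPU suffix (the `self.offload_experts` experts in the loading code above) crosses the host link. The `safety` factor is an arbitrary headroom margin, not a value from SGLang.

```python
import math


def max_offload_experts(
    compute_window_s: float,
    bytes_per_expert: float,
    b_h2d_bytes_per_s: float,
    safety: float = 0.8,
) -> int:
    """Hypothetical tuning helper: experts per layer whose H2D copy fits the window.

    Only the offloaded experts cross the host link; the GPU-resident prefix is a
    device-to-device copy and is ignored here. `safety` < 1 keeps headroom for
    bandwidth below peak and run-to-run jitter (the 0.8 default is arbitrary).
    """
    if min(compute_window_s, bytes_per_expert, b_h2d_bytes_per_s) <= 0:
        raise ValueError("inputs must be positive")
    budget_bytes = safety * b_h2d_bytes_per_s * compute_window_s
    return max(0, math.floor(budget_bytes / bytes_per_expert))


if __name__ == "__main__":
    # Illustrative numbers only: a 5 ms per-layer compute window measured from a
    # profile, 23 MB per expert, and the 450 GB/s peak host-to-device bandwidth.
    print(max_offload_experts(5e-3, 23e6, 450e9))
```

Halving the compute window (for example, a shorter prefill) roughly halves the number of experts you can afford to offload, which is exactly the shrinking margin described above.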
That is why FlashOffload is mostly a prefill tool. Decode does not usually give us enough time per step to hide large CPU transfers.
Why This Fits Disaggregated Prefill
Disaggregated serving separates prefill and decode onto different workers.
Decode workers repeatedly generate tokens, usually with small per-step compute and high memory pressure. Prefill workers see the opposite workload: large prompt batches, high arithmetic intensity, and much more opportunity to keep tensor cores busy.
That makes prefill workers a natural place to use offloading. We can imagine a serving system where decode runs on stable, fully-resident GPU deployments, while transient or spot GPUs are used as extra prefill capacity. Those prefill workers may not have enough HBM to hold the full model comfortably, but they can still contribute useful throughput if moving the offloaded weights can be hidden behind compute.
This is the core deployment story for FlashOffload:
- Use fewer GPUs than would normally be required to fit the whole model.
- Keep only the active layer buffers resident in HBM.
- Use long prefill compute windows to hide CPU-to-GPU transfers.
- Turn opportunistic GPU capacity into useful prefill throughput.
The result is not the same as having the whole model resident in HBM. But for asynchronous prefill capacity, it may be the right trade.
Can We Make Decode Work?
Decode is harder because the compute window is too small. But there is one interesting escape hatch: speculative decoding.
In ordinary decode, the target model verifies one new token at a time. In speculative decoding, a draft source proposes several future tokens, and the target model verifies them in a single pass. If many of those tokens are accepted, the target model skips several serial decode steps.
For FlashOffload, speculation has another effect: it makes the target model’s verification pass wider. Instead of verifying one token, we verify many candidate tokens at once. That increases the amount of compute available to hide the CPU-to-GPU transfer.
In high-throughput serving, speculative decoding is often less useful than expected. Large draft lengths can slow down verification, especially for MoE models where more tokens activate more experts and increase memory traffic. But offloaded decode changes the accounting. If a large host-to-device copy is already happening in the background, we have a much bigger verification budget than usual.
The ideal shape would be something like this:
- Start a long host-to-device copy for the next offloaded weights.
- During that copy window, generate or retrieve a large speculative draft.
- Use the target model to verify many draft tokens while the transfer is still hidden.
Figure 5: Speculative decoding can create a wider verification step, giving the offload copy more compute to hide behind.
There are practical problems. Very long speculation depths are wasteful if the acceptance rate is not high. A thousand draft tokens are not useful if hardly any are accepted, and most current drafters are not good enough to reliably produce very deep accepted branches. Generating those drafts can also become expensive unless the draft generation itself is overlapped with the offload transfer.
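The acceptance problem can be quantified with the standard chain-drafting model from the original speculative decoding analysis (Leviathan et al., 2023), which assumes each draft token is accepted independently with probability alpha. It is a simplification: tree drafts and correlated acceptances behave differently. Under this model, the expected number of tokens produced per verification pass is a geometric series that saturates at 1/(1 - alpha), no matter how deep the draft is.

```python
def expected_tokens_per_verify(alpha: float, depth: int) -> float:
    """Expected tokens produced per target verification pass for one draft chain.

    Assumes each of the `depth` draft tokens is accepted independently with
    probability `alpha` (i.i.d. model); includes the one token the target
    always contributes. Equals 1 + alpha + alpha**2 + ... + alpha**depth.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if depth < 0:
        raise ValueError("depth must be non-negative")
    if alpha == 1.0:
        return float(depth + 1)
    return (1.0 - alpha ** (depth + 1)) / (1.0 - alpha)


if __name__ == "__main__":
    alpha = 0.8
    for depth in (4, 16, 64, 1000):
        produced = expected_tokens_per_verify(alpha, depth)
        # Fraction of the roughly depth + 1 verified positions that pays off.
        useful = produced / (depth + 1)
        print(f"depth={depth:5d}  tokens/pass={produced:5.2f}  useful={useful:.3f}")
    print(f"ceiling 1/(1 - alpha) = {1 / (1 - alpha):.2f}")
```

With alpha = 0.8, for example, even a 1,000-token chain yields fewer than 5 tokens per pass. Deep drafts therefore only pay off if acceptance is very high, or if the extra verification width is wanted anyway to hide a copy.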
We have an upcoming post about speculative-speculative decoding, which becomes very interesting here. The draft process could itself be structured to use the shadow of the long CPU transfer. But not every speculative method fits. EAGLE and DFlash, for example, are conditioned on target-model latent states, which makes them awkward here because the target model state is exactly what we are trying to avoid depending on while the copy is in flight.
So decode offload is not the main result. It is a direction that could become more plausible with speculation methods that prioritize expensive, high-quality, deep draft trees that can make use of a very large overlap window.
Conclusion
Offloading is usually dismissed because the naive version puts slow memory movement directly in the forward pass. That criticism is correct, but incomplete.
If the hardware has enough CPU-to-GPU bandwidth, and the workload has enough FLOPs, offloading can be scheduled instead of merely endured. FlashOffload applies that idea to SGLang by double-buffering offloaded weights and copying layer $i+1$ while layer $i$ computes.
The result is a more useful offload primitive for prefill-heavy inference. It does not make decode magically cheap, and it does not remove the need to reason about bandwidth. But it does make offloading a viable tool for the place where LLM serving most naturally gives us overlap: large, asynchronous prefill work.