Scheduling in inference engines


Inference engines are systems designed to provide generative AI APIs: you send them a text request, and you get a text response back.

Usually these systems run ‘online’ - that is, they provide a server to which requests can be sent (nowadays in the OpenAI-compatible format, a kind of frozen version of the OpenAI API from a couple of years ago). Examples are vLLM, SGLang, and TGI.

These systems tend to want to run on a GPU for efficiency, and GPUs don’t run ‘online’ - they want to process lots of data at once, in batches. So we need to get the requests - arriving online, over time - into the ‘in-progress batch’ that the LLM on the GPU can work on efficiently. For autoregressive LLMs (i.e. generative text models, like LLaMA, DeepSeek, etc.) this ‘scheduling’ can be done after each generated token; doing the scheduling at the ‘token level’ has come to be called ‘continuous batching’. For non-autoregressive LLMs, we’re stuck with ‘request-level’ scheduling, usually by manually ticking a time clock and processing all the requests that arrived between ticks as a batch.
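
To make the distinction concrete, here is a toy sketch (every name and number here is made up for illustration; none of it is from a real engine) of request-level scheduling versus token-level continuous batching. The point is simply that continuous batching backfills a slot the moment its request finishes:

from dataclasses import dataclass

@dataclass
class Request:
    prompt_len: int
    max_new_tokens: int
    generated: int = 0

    @property
    def done(self) -> bool:
        return self.generated >= self.max_new_tokens

def decode_step(batch: list[Request]) -> None:
    # Pretend to run one forward pass: every active request emits one token.
    for req in batch:
        req.generated += 1

def request_level(queue: list[Request], batch_size: int) -> int:
    # Admit a batch, run it until every member finishes, then admit the next.
    steps = 0
    while queue:
        batch, queue = queue[:batch_size], queue[batch_size:]
        while any(not r.done for r in batch):
            decode_step([r for r in batch if not r.done])
            steps += 1
    return steps

def continuous_batching(queue: list[Request], batch_size: int) -> int:
    # Revisit the batch after every generated token: finished requests leave
    # immediately and waiting requests are pulled in to take their place.
    running, steps = [], 0
    while queue or running:
        while queue and len(running) < batch_size:
            running.append(queue.pop(0))
        decode_step(running)
        steps += 1
        running = [r for r in running if not r.done]
    return steps

def make_requests() -> list[Request]:
    return [Request(prompt_len=16, max_new_tokens=n) for n in (1, 64, 1, 64)]

print(request_level(make_requests(), batch_size=2))        # 128 forward passes
print(continuous_batching(make_requests(), batch_size=2))  # 66 forward passes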

The size of the batch - the sum of the sizes of all the requests in the batch - is constrained by the available memory in the system. The number of requests in the batch is the number of things being worked on at a time. Maximizing the number of actively processing requests is a good way to improve the performance of your inference system. (Probably: we can saturate the compute capacity of the accelerator without using all of the GPU memory, so adding more sequences to the batch doesn’t necessarily improve throughput, but it does improve some other useful properties.)

This is a scheduling problem. We’re going to have a look at how it’s done, and think about how to improve it.

Defining the problem

Performant LLM inference requires us to store some state (the KV cache) for each token that is processed. This has to be stored in VRAM, and the amount of VRAM we have therefore bounds the number of sequences we can process at once.
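
How much memory per token? The dominant per-token state is the KV cache: one key and one value vector per layer, per KV head. A back-of-the-envelope sketch (the model dimensions below are assumptions, roughly the shape of an 8B-class model with grouped-query attention, not any exact config):

# Back-of-the-envelope KV cache sizing; the dimensions are assumptions,
# not any particular model's exact config.
num_layers = 32
num_kv_heads = 8        # grouped-query attention
head_dim = 128
bytes_per_elem = 2      # fp16 / bf16

kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
print(kv_bytes_per_token)                  # 131072 bytes = 128 KiB per token

vram_for_kv = 40 * 1024**3                 # say, 40 GiB left over after weights
print(vram_for_kv // kv_bytes_per_token)   # 327680 tokens of KV cache in total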

The first inference systems would have this working batch be allocated as a rectangular array, like this:

               Sequence Length (Tokens) →
       0    8    16   24   32   40   48   56   64
     ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
  0  │████│████│████│████│████│████│████│████│░░░░│ ← (60)
  1  │████│████│████│████│████│░░░░│░░░░│░░░░│░░░░│ ← (40)
  2  │████│████│████│████│████│████│░░░░│░░░░│░░░░│ ← (48)
  3  │████│████│████│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│ ← (24)
  4  │████│████│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│ ← (16)
  5  │████│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│░░░░│ ← (8)
     └────┴────┴────┴────┴────┴────┴────┴────┴────┘
     ↑                                           ↑
  Start                                    Max Length

The scheduling problem in this case is simple. When a request comes in, we ask: is there a free slot? If there is, we put the request in the slot and start working on it. If there isn’t, we wait. As time goes on, the amount of space used by each sequence grows. This is never a problem, because each slot in our rectangle has exactly enough space for the largest possible size a sequence could grow to.
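
In code, that first generation of scheduler is little more than a list of fixed-size slots. A hypothetical sketch (the class and its names are made up for illustration):

class SlotScheduler:
    def __init__(self, num_slots: int, max_len: int):
        self.max_len = max_len
        self.slots: list = [None] * num_slots   # None means the slot is free
        self.waiting: list = []

    def add_request(self, request) -> None:
        self.waiting.append(request)

    def schedule(self) -> None:
        # Put the oldest waiting request into any free slot; otherwise wait.
        for i, occupant in enumerate(self.slots):
            if occupant is None and self.waiting:
                self.slots[i] = self.waiting.pop(0)
        # No preemption is ever needed: every slot already reserves max_len
        # tokens, the largest size its sequence could possibly grow to.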

There are some pretty obvious problems with this approach. All the gray squares are just wasted memory, and VRAM is expensive; we can’t go around allocating it and not using it.

The solution that most systems have settled on is paged attention. But now that we’ve dealt with the fragmentation problem, we’ve made it possible to fit many more sequences in memory. Each request that we want to process needs to be prefilled, which requires us to allocate memory for each incoming token. Over time, each sequence in the batch will grow by one token per iteration (more, if we do speculative decoding).

We have a hard constraint on the total amount of memory we can use. So we need to make two different decisions:

  1. Which requests are we going to bring into the batch? The obvious thing to do is just to take requests as they come. But (as we’ll see later), this doesn’t necessarily give you maximal performance.

  2. What should we do when we run out of space for existing sequences to continue growing? One of our sequences will need to be ‘preempted’. Which one? When should we bring it back? (Both decisions are sketched as toy code just below.)
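
Something like this, with a single token count standing in for the KV-cache budget. The Seq type and the headroom accounting are simplifications I’ve made up; vLLM’s real logic follows in the next section:

from dataclasses import dataclass

@dataclass
class Seq:
    length: int   # tokens this sequence currently holds in the KV cache

def step(running: list[Seq], waiting: list[Seq], budget: int):
    used = sum(s.length for s in running)
    # Decision 2: make room for every running sequence to grow by one token;
    # preempt the most recently admitted sequence until everything fits.
    while running and used + len(running) > budget:
        victim = running.pop()
        used -= victim.length
        waiting.insert(0, victim)   # it resumes from the front of the queue
    # Decision 1: admit waiting requests in arrival order (plain FCFS), as long
    # as their prompt plus one token of headroom per sequence still fits.
    while waiting and used + waiting[0].length + len(running) + 1 <= budget:
        seq = waiting.pop(0)
        used += seq.length
        running.append(seq)
    return running, waiting

running, waiting = step([], [Seq(3), Seq(5), Seq(2)], budget=12)
# running now holds Seq(3) and Seq(5); admitting Seq(2) as well would need
# 8 + 2 tokens plus 3 of headroom > 12, so it stays in the queue.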

The vLLM scheduling algorithm, annotated

vLLM is probably the most popular inference engine. It’s used by lots of different companies to provide LLM inference.

vLLM’s scheduler implementation can be found here. vLLM makes use of a feature they call ‘chunked prefill’, first outlined (to my knowledge) here, where both prefills and decodes are managed in a single heterogeneous batch. So instead of two different ‘modes’ - one in which you prefill new sequences and add them into the batch, and another in which you run decode on the sequences already in the batch - you just have a single batch containing both partial prefills and decodes.

For each element of the batch, the scheduling decision is ‘by how much should this sequence grow at this iteration?’. For prefills, this can be >1 token (i.e. you could prefill the whole prompt at once). For (non-speculative) decodes it is exactly 1.
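
As a made-up illustration of how a per-step token budget might be split across such a heterogeneous batch (this mirrors the idea rather than vLLM’s actual code):

def plan_step(requests: list[dict], token_budget: int) -> dict[str, int]:
    # Split a per-step token budget across a heterogeneous batch of
    # partial prefills and decodes (hypothetical, simplified logic).
    plan: dict[str, int] = {}
    for req in requests:
        if token_budget == 0:
            break
        if req["remaining_prompt"] > 0:
            # Prefill: take as large a chunk of the prompt as the budget allows.
            n = min(req["remaining_prompt"], token_budget)
        else:
            # (Non-speculative) decode: exactly one token.
            n = 1
        plan[req["id"]] = n
        token_budget -= n
    return plan

plan = plan_step(
    [
        {"id": "a", "remaining_prompt": 0},     # decoding
        {"id": "b", "remaining_prompt": 0},     # decoding
        {"id": "c", "remaining_prompt": 5000},  # still prefilling
    ],
    token_budget=2048,
)
print(plan)   # {'a': 1, 'b': 1, 'c': 2046}: the prefill takes whatever budget
              # is left and finishes over later steps.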

The core of the scheduler API is captured in the SchedulerInterface class (which Scheduler implements), reproduced here with all its beautiful docstrings.

class SchedulerInterface(ABC):
  @abstractmethod
  def schedule(self) -> "SchedulerOutput":
      """Schedule the requests to process in this scheduling step.

        The scheduling decision is made at the iteration level.
        Each scheduling step corresponds to a single forward
        pass of the model. Therefore, this method is called
        repeatedly by a busy loop in the engine.

        Essentially, the scheduler produces a dictionary of
        {req_id: num_tokens} that specifies how many tokens
        to process for each request in this scheduling step.
        For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for
        the requests that are auto-regressively generating new
        tokens one by one. Otherwise, it can be somewhere in
        between in case of chunked prefills, prefix caching,
        speculative decoding, etc.

        Additionally, the scheduler also returns useful data
        about each request or the batch as a whole. The model
        runner will use this information in
        preparing inputs to the model.

        Returns:
            A SchedulerOutput object containing information
            about the scheduled requests.
        """
      ...

  @abstractmethod
  def update_from_output(
      self,
      scheduler_output: "SchedulerOutput",
      model_runner_output: "ModelRunnerOutput"
      ) -> dict[int, "EngineCoreOutputs"]:
      """Update the scheduler state based on the model runner output.

        This method is called after the model runner has processed the
        scheduled requests. The model runner output includes generated
        token ids, draft token ids for next step, etc. The scheduler
        uses this information to update its states, checks the
        finished requests, and returns the output for each request.
      """
      ...

  @abstractmethod
  def add_request(self, request: "Request") -> None:
        """Add a new request to the scheduler's internal queue.

        Args:
            request: The new request being added.
        """
      ...

  @abstractmethod
  def finish_requests(self, request_ids, finished_status) -> None:
      """Finish the requests in the scheduler's internal
      queue. If the request is not in the queue, this
      method will do nothing.

        This method is called in two cases:
        1. When the request is aborted by the client.
        2. When the frontend process detects a stop string
            of the request after de-tokenizing its generated tokens.

        Args:
            request_ids: A single or a list of request IDs.
            finished_status: The finished status of the given requests.
        """
      ...

  # ... more methods

The scheduler implementation maintains several different sources of state:

class Scheduler(SchedulerInterface):
  def __init__(...) ...:
    ...
    self.running: list[Request] = []
    # A queue implementation that yields requests according to the policy. The
    # policy can either be FCFS, or it can be PRIORITY. Either just a list, or
    # some sort of heap (for priority-based)
    self.waiting = create_request_queue(self.policy)

When it comes time to schedule, we start by re-scheduling all the requests already in the ‘running’ list:

class Scheduler(SchedulerInterface):
    ...
    def schedule(self) -> SchedulerOutput:
      ...
      # the maximum number of tokens that can be processed in any given
      # forward pass, across all requests.
      token_budget = self.max_num_scheduled_tokens

      # First, we re-schedule requests that are in the running state.
      while req_index < len(self.running) and token_budget > 0:
          request = self.running[req_index]

          # num_tokens_with_spec: The total length that this sequence would 'like'
          # to grow to. I.e. for decode, current length + 1. For prefill, just
          # the length of the prompt
          # num_output_placeholders: Space that's made for this sequence to
          # grow into(?), in the event that we don't schedule every iteration
          # num_computed_tokens: The current length of the sequence
          num_new_tokens = (request.num_tokens_with_spec +
                            request.num_output_placeholders -
                            request.num_computed_tokens)

          # Use a configuration variable to cap how many new tokens a single
          # (long-prefill) request can be scheduled for in one step.
          if (0 < self.scheduler_config.long_prefill_token_threshold <
                  num_new_tokens):
              num_new_tokens = (
                  self.scheduler_config.long_prefill_token_threshold)

          # Don't go over our total token budget.
          num_new_tokens = min(num_new_tokens, token_budget)

          # Next lines deal with scheduling for encoder-decoder models.
          # ...

          # Now we get into the core of the scheduling!
          while True:
              # We check our KV cache to see if there's space for this request.
              new_blocks = self.kv_cache_manager.allocate_slots(
                  request,
                  num_new_tokens,
                  num_lookahead_tokens=self.num_lookahead_tokens)

              if new_blocks is None:
                  # The request cannot be scheduled.
                  # Remember, we're scheduling the running requests here, so
                  # not putting one back into the batch is preemption!

                  # vLLM supports two scheduling algorithms: priority based,
                  # and FCFS. For priority-based, preempt the lowest-priority
                  # request.
                  if self.policy == SchedulingPolicy.PRIORITY:
                      preempted_req = max(
                          self.running,
                          key=lambda r: (r.priority, r.arrival_time),
                      )
                      self.running.remove(preempted_req)
                      if preempted_req in scheduled_running_reqs:
                          scheduled_running_reqs.remove(preempted_req)
                  else:
                      # the self.running list is kept in insertion order,
                      # so for FCFS, just evict the most recently inserted
                      # request
                      preempted_req = self.running.pop()

Once we’ve finished re-scheduling the running requests, we schedule the waiting requests:

        # Next, schedule the WAITING requests.
        # If we've done any preemption already, no need to put in any waiting
        # requests, since we know they won't fit (NB: this isn't actually true!
        # They might fit, but they have lower priority than the preempted requests
        # according to the scheduling algorithm)
        if not preempted_reqs:
            while self.waiting and token_budget > 0:
                if len(self.running) == self.max_num_running_reqs:
                    break

                # ... First check a bunch of pending states. For example: is
                # the KV cache for this request currently being transferred from
                # a remote KV store? Is a state machine currently being compiled
                # so that we can do structured outputs? If either is true, don't
                # schedule this request yet.

                # Then we do lots of prefix caching logic, ...

                # Then, we see if there's space to schedule the request.
                new_blocks = self.kv_cache_manager.allocate_slots(
                    request,
                    # token sources: either we've already got them from an
                    # 'external' cache, or we've already got them from a 'local'
                    # cache, or we're just about to compute them
                    num_new_tokens + num_external_computed_tokens,
                    num_new_local_computed_tokens,
                    new_computed_blocks,
                    num_lookahead_tokens=effective_lookahead_tokens,
                    delay_cache_blocks=load_kv_async,
                    num_encoder_tokens=num_encoder_tokens,
                )

                if new_blocks is None:
                    # The request cannot be scheduled.
                    break

                # otherwise it can be scheduled, ...

        # ...
        # Return the scheduler data
        new_reqs_data = [
            NewRequestData.from_request(
                req, req_to_new_blocks[req.request_id].get_block_ids())
            for req in scheduled_new_reqs
        ]
        cached_reqs_data = self._make_cached_request_data(
            scheduled_running_reqs,
            scheduled_resumed_reqs,
            num_scheduled_tokens,
            scheduled_spec_decode_tokens,
            req_to_new_blocks,
        )

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            # ...
        )

The algorithm has a few interesting properties:

  1. It’s completely first come, first served. (There are a few edge cases: 1. for encoder-decoder (Whisper) models, there’s some logic I’ve elided that lets requests with shorter encoder prompts schedule before requests with longer encoder prompts, if the higher-priority requests won’t fit into the encoder batch; 2. for requests that have a prefix cache hit in a remote KV cache store, the scheduling algorithm will skip scheduling those requests until the KV cache transfer completes; 3. the same applies to requests whose structured-output FSM is still being constructed.) If a sequence gets preempted, it goes straight to the head of the queue. Even if a preempted sequence is so large that a smaller queued sequence could fit in its place, we don’t take the chance to do so, since that would break the fairness guarantee.

  2. It doesn’t take into account ‘temporal cache locality’. Prefix caching is a big deal for prefill inference: scheduling requests that share a prefix close together in time increases the chance that the cached prefix stays in VRAM for all of them. But doing so would break the FCFS guarantee. (A sketch of what prefix-aware admission could look like follows.)
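
For illustration only, prefix-aware admission could look something like the sketch below: group waiting requests by a shared prompt prefix and admit whole groups together. This is not what vLLM does, precisely because it weakens the FCFS guarantee, and the prefix_key function is a stand-in I’ve invented for real prefix-cache block hashing.

from collections import defaultdict
from itertools import chain

def prefix_key(prompt: str, prefix_len: int = 256) -> str:
    # Stand-in for real prefix-cache block hashing.
    return prompt[:prefix_len]

def reorder_waiting(waiting: list[str]) -> list[str]:
    # Group waiting prompts by shared prefix; arrival order is preserved
    # within each group, and groups are ordered by their earliest arrival.
    groups: dict[str, list[str]] = defaultdict(list)
    for prompt in waiting:
        groups[prefix_key(prompt)].append(prompt)
    return list(chain.from_iterable(groups.values()))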

Pessimization: how to break the vLLM scheduler

One interesting way to think about optimization for complex non-deterministic algorithms like scheduling is to start with pessimization. That is, first we think about all the ways we could produce pathologically bad inputs that would ruin the performance of the algorithm. Then, we tweak the algorithm to handle those sorts of inputs better (think pivot selection in quicksort).

For concreteness, let’s say we’ve set up a vLLM instance, and we’re going to send requests to it. Imagine the KV cache is of size N. Also assume that the maximum sequence length the system will accept is N (you would usually pick a much lower bound for this, partly for the reasons we’ll outline) - i.e. a single user can in principle send a request that will fill up the whole KV cache. For the sake of simplicity, assume we generate just one token per request.

Triggering fragmentation

Then, send the following sequence of requests:

N-1 ┤ █   █   █   █   █   █   █   █   █   █
    │ █   █   █   █   █   █   █   █   █   █
    │ █   █   █   █   █   █   █   █   █   █
    │ █   █   █   █   █   █   █   █   █   █
    │ █   █   █   █   █   █   █   █   █   █
  1 ┤ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄ █ ▄
  0 └─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─...→ Time
      0 1 2 3 4 5 6 7 8 9 ...

What’s the scheduling behaviour? Well, first request 0 of size N-1 schedules. Since it needs 1 token of space to grow into, we can’t schedule request 1. After 1 iteration it completes, so we schedule request 1. But now we can’t schedule request 2! So request 1 completes alone, and then we schedule request 2, which completes alone, then we schedule request 3, which completes alone, etc. Apparently this is called the “convoy effect”.

How could we do better? Well, if our users had been kinder to us, the same requests might have come in like this:

N-1 ┤                     █ █ █ █ █ █ █ █ █ █
    │                     █ █ █ █ █ █ █ █ █ █
    │                     █ █ █ █ █ █ █ █ █ █
    │                     █ █ █ █ █ █ █ █ █ █
    │                     █ █ █ █ █ █ █ █ █ █
  1 ┤ ▄ ▄ ▄ ▄ ▄ ▄ ▄ ▄ ▄ ▄ █ █ █ █ █ █ █ █ █ █
  0 └─┴─┴─┴─┴─┴─┴┴┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴──...→ Time
      0 1 2 3 4 5 6 7 8 9 ...

Then, we could just do the first 10 requests in the same batch (assuming N is at least 20, so that ten two-token sequences fit at once). The rest would still complete in series, but that was going to happen anyway. That’s 11 iterations instead of 20: the same work in roughly half the time!
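
A toy simulation of both orderings (under the same simplifying assumptions: strict FCFS admission, one generated token per request, and one token of headroom per admitted request) reproduces those counts:

def simulate(prompt_lens: list[int], kv_cache: int) -> int:
    # Toy model: each request needs prompt_len + 1 tokens to be admitted,
    # generates exactly one token, and then completes, freeing its memory.
    waiting, steps = list(prompt_lens), 0
    while waiting:
        used, admitted = 0, 0
        # Strict FCFS: stop at the first request that doesn't fit.
        while waiting and used + waiting[0] + 1 <= kv_cache:
            used += waiting.pop(0) + 1
            admitted += 1
        if admitted == 0:     # a request bigger than the whole cache
            break
        steps += 1            # one iteration completes every admitted request
    return steps

N = 1024
interleaved = [N - 1, 1] * 10           # big, small, big, small, ...
small_first = [1] * 10 + [N - 1] * 10   # the same requests, reordered

print(simulate(interleaved, N))   # 20 -- every request runs alone
print(simulate(small_first, N))   # 11 -- the ten small requests share a batch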

So one idea is to sort the incoming requests! But this brings with it the problem of ‘starvation’. If you have some system for being ‘a bit unfair’ to sequences for performance reasons, you have to bound this unfairness, otherwise pathological conditions can leave some requests postponed forever. For example, if the long requests come from user A and the short ones from user B, and the load from user B is sufficient to keep the working batch full, then user A never gets a turn.
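
One common way to bound that unfairness is aging: prefer short requests, but never let anything be jumped over for more than some deadline. A hypothetical sketch, not anything vLLM implements (the request dicts and the 30-second bound are made up):

import time

MAX_WAIT_S = 30.0   # assumed bound on how long a request may be jumped over

def admission_order(waiting: list[dict]) -> list[dict]:
    now = time.monotonic()
    overdue = [r for r in waiting if now - r["arrival"] > MAX_WAIT_S]
    fresh = [r for r in waiting if now - r["arrival"] <= MAX_WAIT_S]
    # Overdue requests go first, in arrival order, so nobody waits forever;
    # everything else is sorted shortest-prompt-first to pack batches tightly.
    overdue.sort(key=lambda r: r["arrival"])
    fresh.sort(key=lambda r: r["prompt_len"])
    return overdue + fresh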

Triggering thrashing

Even when the batch is too full to bring in new sequences, all the sequences already in the batch grow over time. When they grow too large, one (or more) of them will get preempted. When a sequence is preempted, we throw away all the work we’ve done on it so far (or swap it out, but the default setting in vLLM is to recompute), and the sequence has to start again later.

Given a sequence of requests, how can a nefarious user order them to maximise the amount of useless compute that the engine performs?

First, simplify. Assume each sequence has 1 prefill token. Assume a KV cache size of 16. Assume that each sequence wants to generate 16 output tokens. Assume an infinite queue of sequences.

At time 0, 8 sequences will join the batch (we need to give each sequence at least 1 token to grow into). At time 1, the KV cache will be full (16 tokens: 8 prefill tokens, 8 decode tokens), so we need to evict some sequences. Since we need 1 token of space for each request that proceeds to the next iteration, we have to preempt the last 3 sequences. At time 2, the same thing will happen again, so we cut down our batch size again, and again, … Once the first sequence completes, we can start the process again with the second sequence in the first slot. Only the sequence in the first slot ever completes.


  0 ░ → 0 ░▓ → 0 ░▓▓ → 0 ░▓▓▓ → ... 0 ░▓▓▓ → ... →  0 ░▓▓▓▓▓▓▓▓▓▓
  1 ░   1 ░▓   1 ░▓▓   1 ░▓▓▓       1 ░▓▓▓            x
  2 ░   2 ░▓   2 ░▓▓   2 ░▓▓▓         x
  3 ░   3 ░▓   3 ░▓▓     x
  4 ░   4 ░▓     x
  5 ░   5 ░▓     x
  6 ░   6 ░▓     x
  7 ░   7 ░▓     x

  Legend: ░ = Prefill, ▓ = Decode, x = Preempted
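
For the record, here’s the arithmetic behind that first cut, i.e. why exactly three sequences get preempted at time 1:

# Every kept sequence holds 2 tokens (1 prefill + 1 decode) and needs
# 1 more token of headroom to proceed to the next iteration.
cache, seqs, seq_len = 16, 8, 2
keep = seqs
while keep * seq_len + keep > cache:
    keep -= 1
print(seqs - keep, "preempted;", keep, "keep running")   # 3 preempted; 5 keep running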

Each completion of length N triggers O(N) preemptions. Is this optimal? Pessimal?

Conclusion

The vLLM scheduler is simple: FCFS, preempt when you run out of space, resume when you can. It’s predictable, it’s fair, and it works.

But the pathological cases aren’t academic. Send requests in the wrong order and you trigger convoy effects. Let the batch grow too large and you get thrashing.

The problem gets harder as systems improve. Longer context windows mean more fragmentation risk. Prefix caching creates exploitable locality. Multi-tenant workloads have competing priorities. Current schedulers mostly ignore these pressures.

Whether anyone builds something smarter, or whether simple continues to win, is hard to say.


Last modified: 23 Oct 2025