
Tuning Batch Size for Your Workload

Goldilocks wasn't an engineer, but she understood optimization. Too small, too big, just right.

Batch size is the most important knob in LLM serving. Get it wrong and you either waste expensive GPUs or make users wait.

The Optimization Space

Three constraints bound your batch size:

Memory: Larger batches need more KV cache. At some point, you OOM.

Latency: Larger batches increase per-request latency. At some point, you miss your SLA.

Throughput: Smaller batches waste GPU compute. At some point, your costs explode.

def find_optimal_batch_size(
    model,
    memory_limit_gb: float,
    latency_sla_ms: float,
    min_throughput_tps: float,
    prompt_length: int,
    output_length: int,
) -> int:
    results = []

    for batch_size in [1, 2, 4, 8, 16, 32, 64]:
        try:
            memory, latency, throughput = benchmark(
                model, batch_size, prompt_length, output_length
            )
        except OOMError:
            break  # Hit memory limit

        if memory > memory_limit_gb:
            break  # Memory constraint

        if latency > latency_sla_ms:
            break  # Latency constraint

        if throughput >= min_throughput_tps:
            results.append({
                "batch_size": batch_size,
                "memory_gb": memory,
                "latency_ms": latency,
                "throughput_tps": throughput,
            })

    if not results:
        raise ValueError("No batch size satisfied every constraint")

    # Return the largest batch that meets all constraints
    return max(results, key=lambda x: x["batch_size"])["batch_size"]
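
As an illustration, a sweep for a chat-style workload on a single 80 GB GPU might be invoked like this; the limits and lengths below are placeholders you'd swap for your own SLA and traffic shape:

# Illustrative constraints only -- substitute your own limits and workload stats
optimal = find_optimal_batch_size(
    model=model,                 # handle from your serving stack
    memory_limit_gb=70.0,        # leave headroom below the 80 GB card
    latency_sla_ms=1000.0,       # end-to-end budget per request
    min_throughput_tps=500.0,    # tokens/sec needed to hit cost targets
    prompt_length=512,           # median prompt length in the workload
    output_length=256,           # median completion length
)
print(f"Serving with batch size {optimal}")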

The Benchmarking Protocol

Synthetic benchmarks lie. Benchmark with your actual workload:

import random
import time

def realistic_benchmark(
    model,
    batch_size: int,
    workload_sample: list[Request],
    duration_seconds: int = 60,
) -> BenchmarkResult:
    results = []
    start = time.time()

    while time.time() - start < duration_seconds:
        # Sample requests matching your production distribution
        batch = random.sample(workload_sample, batch_size)

        batch_start = time.time()
        outputs = model.generate_batch(batch)
        batch_latency = time.time() - batch_start

        for i, output in enumerate(outputs):
            results.append({
                "input_tokens": batch[i].input_tokens,
                "output_tokens": len(output),
                "latency_ms": batch_latency * 1000,
            })

    total_tokens = sum(r["output_tokens"] for r in results)
    elapsed = time.time() - start

    return BenchmarkResult(
        throughput_tps=total_tokens / elapsed,
        latency_p50=percentile([r["latency_ms"] for r in results], 50),
        latency_p99=percentile([r["latency_ms"] for r in results], 99),
        memory_gb=get_gpu_memory_usage(),
    )

The key: draw from your actual prompt-length and output-length distributions rather than fixed synthetic sizes.
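
One way to get those distributions is to replay a slice of production traffic. Here's a minimal sketch, assuming you log prompt and completion token counts per request; the log format and the Request constructor are placeholders to adapt to your own schema:

import json

def load_workload_sample(log_path: str, limit: int = 1000) -> list[Request]:
    # Each log line is assumed to carry the token counts recorded at serve time
    sample = []
    with open(log_path) as f:
        for line in f:
            record = json.loads(line)
            sample.append(Request(
                input_tokens=record["prompt_tokens"],
                max_output_tokens=record["completion_tokens"],
            ))
            if len(sample) >= limit:
                break
    return sample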

Dynamic Batch Sizing

Fixed batch size optimizes for average load. Production has variance.

from statistics import mean

class DynamicBatcher:
    def __init__(
        self,
        min_batch: int = 1,
        max_batch: int = 32,
        target_latency_ms: float = 500,
    ):
        self.min_batch = min_batch
        self.max_batch = max_batch
        self.target_latency = target_latency_ms
        self.current_batch_size = 8  # Start middle
        self.recent_latencies = []

    def adjust_batch_size(self, last_latency_ms: float):
        self.recent_latencies.append(last_latency_ms)
        if len(self.recent_latencies) > 100:
            self.recent_latencies.pop(0)

        avg_latency = mean(self.recent_latencies)

        if avg_latency > self.target_latency * 1.1:
            # Too slow, reduce batch
            self.current_batch_size = max(
                self.min_batch,
                self.current_batch_size - 2
            )
        elif avg_latency < self.target_latency * 0.7:
            # Room to grow, increase batch
            self.current_batch_size = min(
                self.max_batch,
                self.current_batch_size + 2
            )

        return self.current_batch_size

When latency creeps up, shrink batches. When latency is comfortable, grow them.
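
Wired into a serving loop, that might look like the sketch below; the request queue and its take method are placeholders for whatever your scheduler exposes:

import time

batcher = DynamicBatcher(max_batch=32, target_latency_ms=500)
batch_size = batcher.current_batch_size

while True:
    batch = request_queue.take(batch_size)  # hypothetical queue API
    if not batch:
        continue

    start = time.time()
    model.generate_batch(batch)
    latency_ms = (time.time() - start) * 1000

    # Feed the observed latency back in; the next batch uses the adjusted size
    batch_size = batcher.adjust_batch_size(latency_ms)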

Memory-Aware Batching

Token count matters more than request count for memory:

class MemoryAwareBatcher:
    def __init__(self, max_tokens: int = 100_000):
        self.max_tokens = max_tokens

    def make_batch(self, pending: list[Request]) -> list[Request]:
        batch = []
        tokens = 0

        for request in pending:
            # Reserve the worst case: prompt plus the full output budget
            request_tokens = request.input_tokens + request.max_output_tokens

            if tokens + request_tokens > self.max_tokens:
                break  # Stop at the first request that doesn't fit, preserving FIFO order

            batch.append(request)
            tokens += request_tokens

        return batch

This naturally adapts: more requests when they're small, fewer when they're large.
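
To see the adaptation, feed it a mixed queue; the Request constructor here is assumed to take the same token fields used throughout this post:

batcher = MemoryAwareBatcher(max_tokens=100_000)

pending = [
    Request(input_tokens=30_000, max_output_tokens=2_000),  # long document
    Request(input_tokens=500, max_output_tokens=500),       # short chat turn
    Request(input_tokens=70_000, max_output_tokens=4_000),  # would exceed the budget
    Request(input_tokens=800, max_output_tokens=300),
]

batch = batcher.make_batch(pending)
# The first two requests fit (~33,000 reserved tokens); the third would push
# past 100,000, so the batch closes at two requests.

Note the tradeoff in the break: the fourth request would also fit, but skipping ahead would let small requests starve the large one at the head of the queue.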

The Interaction with Continuous Batching

With continuous batching, "batch size" becomes "max concurrent requests":

class ContinuousBatchConfig:
    def __init__(self):
        self.max_num_seqs = 32           # Max concurrent sequences
        self.max_num_batched_tokens = 4096  # Tokens per iteration step
        self.max_model_len = 8192        # Max sequence length

    def can_add_request(
        self, current_seqs: int, current_tokens: int, new_request: Request
    ) -> bool:
        if current_seqs >= self.max_num_seqs:
            return False

        # Reject requests that can't fit in the context window at all
        if new_request.input_tokens + new_request.max_output_tokens > self.max_model_len:
            return False

        # Check the per-iteration token budget (prompt tokens are processed this step)
        return current_tokens + new_request.input_tokens <= self.max_num_batched_tokens

vLLM's max_num_seqs is the equivalent of batch size for continuous batching.
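
For reference, here's how those knobs look when constructing vLLM's offline LLM engine; parameter names follow vLLM's EngineArgs and can shift between releases, so verify against the version you deploy:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model
    max_num_seqs=32,               # max concurrent sequences per step
    max_num_batched_tokens=4096,   # per-iteration token budget
    max_model_len=8192,            # max sequence length
)

prompts = ["Explain continuous batching in one sentence."]
outputs = llm.generate(prompts, SamplingParams(max_tokens=256))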

Common Mistakes

Optimizing for peak throughput: Batch size 64 might give the best throughput, but at 5 seconds of latency, users leave before they see results.

Ignoring variance: Average latency of 500ms is fine. But if P99 is 10 seconds because some batches have all-long requests, you have a problem.

Static sizing for dynamic load: Fixed batch size 16 is great at 10 QPS. At 100 QPS, you need to adapt.

Forgetting prefill: Batch size tuning usually focuses on decode. But prefill behaves differently: it's compute-bound, while decode is memory-bandwidth-bound, so a batch size that's efficient for one phase may not be for the other.
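
One way to keep the distinction visible is to time the two phases separately when you benchmark. A rough sketch, assuming a streaming generation call that yields at least one token (generate_batch_stream is a placeholder name):

import time

def split_phase_latency(model, batch: list[Request]) -> dict:
    # Time-to-first-token approximates prefill cost; the gaps between
    # subsequent tokens approximate per-token decode cost.
    start = time.time()
    token_times = []

    for _token in model.generate_batch_stream(batch):  # hypothetical streaming API
        token_times.append(time.time())

    decode_gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    return {
        "prefill_ms": (token_times[0] - start) * 1000,
        "decode_ms_per_token": 1000 * sum(decode_gaps) / max(len(decode_gaps), 1),
    }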

The right batch size isn't a number. It's a policy that adapts to conditions while respecting constraints.