Tuning Batch Size for Your Workload
Goldilocks wasn't an engineer, but she understood optimization. Too small, too big, just right.
Batch size is the most important knob in LLM serving. Get it wrong and you either waste expensive GPUs or make users wait.
The Optimization Space
Three constraints bound your batch size:
Memory: Larger batches need more KV cache. At some point, you OOM.
Latency: Larger batches increase per-request latency. At some point, you miss SLA.
Throughput: Smaller batches waste GPU compute. At some point, your costs explode.
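To get a feel for the memory constraint, a rough KV-cache estimate goes a long way. The sketch below uses the usual per-token formula (two tensors per layer × KV heads × head dimension × bytes per value) with illustrative 7B-class dimensions; substitute your own model's numbers.
def kv_cache_gb(
    batch_size: int,
    seq_len: int,
    num_layers: int = 32,     # illustrative 7B-class dimensions
    num_kv_heads: int = 32,
    head_dim: int = 128,
    dtype_bytes: int = 2,     # fp16 / bf16
) -> float:
    # K and V tensors for every layer, for every token in every sequence
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return batch_size * seq_len * per_token / 1e9

# 32 concurrent requests at 4,096 tokens each already need roughly 69 GB of KV cache
print(f"{kv_cache_gb(32, 4096):.1f} GB")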
A simple sweep over candidate batch sizes makes the trade-off concrete (benchmark and OOMError below stand in for your own benchmarking harness):
def find_optimal_batch_size(
    model,
    memory_limit_gb: float,
    latency_sla_ms: float,
    min_throughput_tps: float,
    prompt_length: int,
    output_length: int,
) -> int:
    results = []
    for batch_size in [1, 2, 4, 8, 16, 32, 64]:
        try:
            memory, latency, throughput = benchmark(
                model, batch_size, prompt_length, output_length
            )
        except OOMError:
            break  # Out of GPU memory; larger batches won't fit either
        if memory > memory_limit_gb:
            break  # Memory constraint violated
        if latency > latency_sla_ms:
            break  # Latency SLA violated; larger batches will only be slower
        if throughput >= min_throughput_tps:
            results.append({
                "batch_size": batch_size,
                "memory_gb": memory,
                "latency_ms": latency,
                "throughput_tps": throughput,
            })
    if not results:
        raise ValueError("No batch size satisfies all constraints")
    # Return the largest batch size that meets every constraint
    return max(results, key=lambda x: x["batch_size"])["batch_size"]
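A call might look like the following; the numbers are placeholders for an 80 GB GPU with a 1-second latency SLA, not recommendations.
best = find_optimal_batch_size(
    model,                      # your loaded model handle
    memory_limit_gb=72,         # leave headroom below an 80 GB card
    latency_sla_ms=1_000,
    min_throughput_tps=1_000,
    prompt_length=512,          # use your median prompt length
    output_length=256,          # and median output length
)
print(f"Largest batch size that meets every constraint: {best}")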
The Benchmarking Protocol
Synthetic benchmarks lie. Benchmark with your actual workload:
import random
import time

# BenchmarkResult, Request, percentile, and get_gpu_memory_usage are assumed
# to come from your benchmarking harness.
def realistic_benchmark(
    model,
    batch_size: int,
    workload_sample: list[Request],
    duration_seconds: int = 60,
) -> BenchmarkResult:
    results = []
    start = time.time()
    while time.time() - start < duration_seconds:
        # Sample requests matching your production distribution
        batch = random.sample(workload_sample, batch_size)
        batch_start = time.time()
        outputs = model.generate_batch(batch)
        batch_latency = time.time() - batch_start
        for i, output in enumerate(outputs):
            results.append({
                "input_tokens": batch[i].input_tokens,
                "output_tokens": len(output),       # assumes outputs are token lists
                "latency_ms": batch_latency * 1000,
            })
    total_tokens = sum(r["output_tokens"] for r in results)
    elapsed = time.time() - start
    return BenchmarkResult(
        throughput_tps=total_tokens / elapsed,
        latency_p50=percentile([r["latency_ms"] for r in results], 50),
        latency_p99=percentile([r["latency_ms"] for r in results], 99),
        memory_gb=get_gpu_memory_usage(),
    )
Key: use your actual prompt length distribution and output length distribution.
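One way to build workload_sample is to replay lengths from production logs. The sketch below assumes a JSONL log with per-request input_tokens and output_tokens fields, and uses a hypothetical Request dataclass (also assumed by the batchers later in this section) as a stand-in for your server's request type.
import json
import random
from dataclasses import dataclass

@dataclass
class Request:
    prompt: str
    input_tokens: int
    max_output_tokens: int

def load_workload_sample(log_path: str, n: int = 1000) -> list[Request]:
    # Draw n requests whose length distribution mirrors production traffic
    with open(log_path) as f:
        records = [json.loads(line) for line in f]
    sampled = random.sample(records, min(n, len(records)))
    return [
        Request(
            prompt="x " * r["input_tokens"],       # synthetic text, realistic length;
            input_tokens=r["input_tokens"],        # replay real prompts if you can
            max_output_tokens=r["output_tokens"],
        )
        for r in sampled
    ]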
Dynamic Batch Sizing
Fixed batch size optimizes for average load. Production has variance.
from collections import deque
from statistics import mean

class DynamicBatcher:
    def __init__(
        self,
        min_batch: int = 1,
        max_batch: int = 32,
        target_latency_ms: float = 500,
    ):
        self.min_batch = min_batch
        self.max_batch = max_batch
        self.target_latency = target_latency_ms
        self.current_batch_size = 8  # Start in the middle of the range
        self.recent_latencies = deque(maxlen=100)  # Rolling window of measurements

    def adjust_batch_size(self, last_latency_ms: float) -> int:
        self.recent_latencies.append(last_latency_ms)
        avg_latency = mean(self.recent_latencies)
        if avg_latency > self.target_latency * 1.1:
            # Too slow: reduce the batch size
            self.current_batch_size = max(
                self.min_batch,
                self.current_batch_size - 2,
            )
        elif avg_latency < self.target_latency * 0.7:
            # Room to grow: increase the batch size
            self.current_batch_size = min(
                self.max_batch,
                self.current_batch_size + 2,
            )
        return self.current_batch_size
When latency creeps up, shrink batches. When latency is comfortable, grow them.
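In a serving loop, the batcher sits between the request queue and the model. A minimal sketch, assuming a queue object with a take(n) method that pops up to n pending requests, and the generate_batch call used in the benchmark above:
import time

def serve_forever(model, queue, batcher: DynamicBatcher):
    while True:
        batch = queue.take(batcher.current_batch_size)  # hypothetical queue API
        if not batch:
            time.sleep(0.005)   # nothing pending; avoid a busy loop
            continue
        start = time.time()
        model.generate_batch(batch)
        latency_ms = (time.time() - start) * 1000
        batcher.adjust_batch_size(latency_ms)  # feed the measurement back in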
Memory-Aware Batching
Token count matters more than request count for memory:
class MemoryAwareBatcher:
    def __init__(self, max_tokens: int = 100_000):
        self.max_tokens = max_tokens

    def make_batch(self, pending: list[Request]) -> list[Request]:
        batch = []
        tokens = 0
        for request in pending:
            # Budget for the worst case: full prompt plus maximum output
            request_tokens = request.input_tokens + request.max_output_tokens
            if batch and tokens + request_tokens > self.max_tokens:
                break  # Budget exhausted; stopping (not skipping) preserves FIFO order
            # Always admit at least one request so an oversized request can't stall the queue
            batch.append(request)
            tokens += request_tokens
        return batch
This naturally adapts: more requests when they're small, fewer when they're large.
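To see the adaptation, feed the batcher queues of short versus long requests (using the hypothetical Request dataclass from the benchmarking section):
batcher = MemoryAwareBatcher(max_tokens=100_000)

short = [Request(prompt="", input_tokens=500, max_output_tokens=500)] * 200
long = [Request(prompt="", input_tokens=8_000, max_output_tokens=2_000)] * 200

print(len(batcher.make_batch(short)))  # 100 requests fit at 1,000 tokens each
print(len(batcher.make_batch(long)))   # only 10 fit at 10,000 tokens each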
The Interaction with Continuous Batching
With continuous batching, "batch size" becomes "max concurrent requests":
class ContinuousBatchConfig:
    def __init__(self):
        self.max_num_seqs = 32               # Max concurrent sequences
        self.max_num_batched_tokens = 4096   # Token budget per iteration step
        self.max_model_len = 8192            # Max sequence length

    def can_add_request(
        self, current_seqs: int, current_tokens: int, new_request: Request
    ) -> bool:
        if current_seqs >= self.max_num_seqs:
            return False  # Sequence slots exhausted
        if new_request.input_tokens > self.max_model_len:
            return False  # Request can never fit in the context window
        # Check the token budget for this iteration (prefill tokens count here)
        return current_tokens + new_request.input_tokens <= self.max_num_batched_tokens
vLLM's max_num_seqs is the equivalent of batch size for continuous batching.
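For example, with vLLM's offline LLM class the same limits are passed as engine arguments (argument names as of recent vLLM releases; check your version's engine-arguments documentation):
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model; use your own
    max_num_seqs=32,                 # max concurrent sequences ("batch size")
    max_num_batched_tokens=4096,     # token budget per scheduler step
    max_model_len=8192,              # max sequence length
    enable_chunked_prefill=True,     # lets the step budget be smaller than max_model_len
    gpu_memory_utilization=0.90,     # fraction of GPU memory for weights + KV cache
)

outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))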
Common Mistakes
Optimizing for peak throughput: Batch size 64 might give best throughput, but latency is 5 seconds. Users leave before seeing results.
Ignoring variance: An average latency of 500 ms looks fine. But if P99 is 10 seconds because some batches contain only long requests, you have a problem (see the sketch after this list).
Static sizing for dynamic load: Fixed batch size 16 is great at 10 QPS. At 100 QPS, you need to adapt.
Forgetting prefill: Batch size tuning usually focuses on decode. But batching prefill behaves differently: prefill is compute-bound, while decode is memory-bandwidth-bound.
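The variance point is easy to check against your own latency logs; the synthetic numbers below are only for illustration.
import random
from statistics import mean, quantiles

# 97% of batches finish around 450 ms; 3% stall near 9 s because every
# request in them happens to be long.
latencies_ms = (
    [random.gauss(450, 50) for _ in range(970)]
    + [random.gauss(9_000, 500) for _ in range(30)]
)

print(f"mean: {mean(latencies_ms):.0f} ms")                  # ~700 ms, looks fine
print(f"p99:  {quantiles(latencies_ms, n=100)[98]:.0f} ms")  # ~9,000 ms, misses any sane SLA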
The right batch size isn't a number. It's a policy that adapts to conditions while respecting constraints.