Speculative Decoding for PLaMo 2 in llama.cpp: When a 6× Faster Draft Model Still Loses
Measuring draft-model and n-gram speculation for the Mamba-hybrid pfnet/plamo-2-8b
(target) with pfnet/plamo-2-1b (draft) on Apple Silicon — a verified-lossless, mostly negative result,
plus a novel quantization bug found along the way.
1. Why this experiment
Speculative decoding is the standard recipe for accelerating autoregressive inference: a small draft model proposes several tokens, the large target model verifies them in a single batched forward pass, and accepted tokens come essentially for free. With greedy verification the output is provably identical to the target model decoding alone — the speedup is supposed to be a free lunch.
PLaMo 2 (Preferred Networks) is an interesting test case because it is not a pure transformer.
It is a Samba-style hybrid that interleaves Mamba2 (SSM) layers with sliding-window attention
(PLaMo 2 Technical Report, arXiv:2509.04897). The model family ships a
natural draft/target pair: plamo-2-1b and plamo-2-8b, sharing a tokenizer. Standalone, the 1B
decodes at 134 tok/s versus the 8B’s 21.4 tok/s on this machine —
a 6.3× ratio that looks, on paper, like ideal speculation territory.
The catch: speculative decoding requires rolling back the target model’s state when a drafted token is rejected. For transformers that is trivial. For SSMs it is not — and that is the real subject of this report.
Setup. Both models in GGUF Q8_0 with the ssm_out tensor kept bf16 (“Q8fixed”,
see section 4), greedy sampling (temp 0, top_k 1), n_predict 256,
12 prompts × 3 repeats per configuration across three tasks: Japanese chat (ja_chat),
English→Japanese translation (translation_en_ja), and code generation (code).
Configurations: fx_baseline (no speculation), fx_spec_d{2,4,8,16}
(--spec-type draft-simple with draft-n-max = 2/4/8/16, draft-n-min 1), and
fx_ngram (--spec-type ngram-simple, prompt-lookup drafting, no draft model).
2. The SSM rollback problem
A transformer’s KV cache is a per-token data structure: rejecting drafted tokens just means truncating the cache back to the last accepted position. A Mamba/SSM layer instead carries a fixed-size recurrent state that is overwritten in place at every step. Once you have advanced it through five speculative tokens, there is no per-token record to truncate back to — the information is destructively folded into the state.
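A toy contrast makes the asymmetry concrete. In this illustrative sketch, the "KV cache" is a Python list with one entry per token, and the SSM layer is a hypothetical scalar recurrence h = a·h + b·x. Neither is llama.cpp's actual data structure, and the numbers are arbitrary.

```python
# Toy illustration (not llama.cpp code): truncating a per-token cache vs.
# rolling back a recurrent state that is overwritten in place.

def ssm_step(h, x, a=0.9, b=0.1):
    """Hypothetical scalar recurrence standing in for a Mamba2 state update."""
    return a * h + b * x

accepted = [1.0, 2.0, 3.0]   # tokens already accepted (toy input values)
drafted = [4.0, 5.0]         # speculative tokens; assume both are rejected

# KV cache: one entry per token, so rejection is just truncation.
kv_cache = [("kv", x) for x in accepted + drafted]
del kv_cache[len(accepted):]

# SSM: one fixed-size state; each step replaces the previous value.
h = 0.0
for x in accepted:
    h = ssm_step(h, x)
h_at_last_accept = h          # this value survives only because we copied it here
for x in drafted:
    h = ssm_step(h, x)

# h now mixes all five tokens and carries no per-token record to cut back to.
# (This scalar toy could be inverted algebraically; the engines discussed in
# this report roll back by restoring saved snapshots instead.)
print(len(kv_cache), h_at_last_accept, h)
```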
This was historically a hard blocker. As recently as late 2025, vLLM raised
NotImplementedError("Mamba with speculative decoding is not supported yet") for all speculative
decoding methods on Mamba-hybrid models
(vllm-project/vllm #30114) — precisely because the
SSM fixed-size recurrent state cannot be truncated like a KV cache. By mid-2026 the problem was solved generically in
three stacks (llama.cpp, and conditionally vLLM and SGLang); see the platform landscape below.
llama.cpp solves it generically. llama_memory_recurrent keeps a rolling checkpoint buffer
(depth 8 per sequence) of SSM state tensors, saved at batch boundaries and restored on rejection
(ggml-org/llama.cpp PR #19493, merged April 2026; an
earlier alternative, #20075, was closed unmerged). In our runs the server
logged roughly 33.5 MiB per context checkpoint for the 8B model. We run our experiment on
llama.cpp on a local MacBook (Apple M4 Pro, Metal) and do not benchmark the other engines. llama.cpp gives a
clean, accessible implementation in which to ask the empirical question this report answers: does speculative decoding
ever pay off here?
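The checkpoint idea can be sketched in a few lines. This is a simplified model under stated assumptions, not llama.cpp's llama_memory_recurrent: the state is a list of floats, a snapshot is taken before each verification batch, the buffer keeps the last 8 snapshots (mirroring the depth-8 buffer described above), and on rejection the sketch restores the pre-batch snapshot and replays only the accepted tokens. The class and function names are hypothetical.

```python
from collections import deque

def ssm_step(state, x, a=0.9, b=0.1):
    """Hypothetical per-channel recurrence standing in for a Mamba2 update."""
    return [a * h + b * x for h in state]

class CheckpointedState:
    """Toy recurrent state plus a rolling buffer of (position, snapshot) pairs."""

    def __init__(self, dim, depth=8):
        self.state = [0.0] * dim
        self.pos = 0                             # tokens folded into the state
        self.checkpoints = deque(maxlen=depth)   # oldest snapshot is evicted first

    def save(self):
        self.checkpoints.append((self.pos, list(self.state)))  # copy, not alias

    def advance(self, tokens):
        for x in tokens:
            self.state = ssm_step(self.state, x)
            self.pos += 1

    def restore(self, pos):
        """Roll back to the newest snapshot taken at or before `pos`."""
        while self.checkpoints and self.checkpoints[-1][0] > pos:
            self.checkpoints.pop()
        if not self.checkpoints:
            raise RuntimeError("no checkpoint old enough; recompute from scratch")
        self.pos, snap = self.checkpoints[-1]
        self.state = list(snap)

def verify_round(mem, drafted, n_accepted):
    """Fold a drafted batch into the state, then keep only n_accepted tokens."""
    start = mem.pos
    mem.save()                               # snapshot at the batch boundary
    mem.advance(drafted)                     # state now includes every draft
    if n_accepted < len(drafted):
        mem.restore(start)                   # undo the whole batch...
        mem.advance(drafted[:n_accepted])    # ...and replay the accepted prefix

mem = CheckpointedState(dim=4)
mem.advance([1.0, 2.0])
verify_round(mem, drafted=[3.0, 4.0, 5.0], n_accepted=1)

ref = CheckpointedState(dim=4)
ref.advance([1.0, 2.0, 3.0])                 # plain decoding of the same prefix
print(mem.pos == ref.pos, mem.state == ref.state)
```

The deque's maxlen gives the fixed depth for free. Note that in this sketch save() copies the full state on every round, whether or not anything is later rejected.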
3. Losslessness check
All speculative configurations produced byte-identical greedy output to the baseline, verified by per-prompt SHA-256 of the generated content across all 12 prompts × 3 tasks × all configs. The checkpoint/restore machinery is exact: whatever we conclude about speed, correctness is not in question.
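The per-prompt check amounts to comparing digests of each configuration's output against the baseline. Below is a minimal sketch. It assumes a hypothetical layout in which each configuration's results are a dict mapping prompt id to generated text; the repo's JSONL records may be organized differently.

```python
import hashlib

def sha256_text(text):
    """Hex SHA-256 of the UTF-8 bytes of a generated completion."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

def mismatched_prompts(baseline, candidate):
    """Return prompt ids whose output differs from the baseline (empty = lossless)."""
    only_in_one = set(baseline) ^ set(candidate)
    if only_in_one:
        raise ValueError(f"prompt sets differ: {sorted(only_in_one)}")
    return sorted(
        pid for pid in baseline
        if sha256_text(baseline[pid]) != sha256_text(candidate[pid])
    )

base = {"ja_chat/01": "こんにちは", "code/01": "def f():\n    pass\n"}
spec = dict(base)
print(mismatched_prompts(base, spec))
```

Hashing is equivalent to comparing the strings directly. Its advantage is that the digests are small enough to store alongside each JSONL record and compare later.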
4. A detour: the ssm_out Q8_0 quantization bug
Before any benchmark could run, the straightforward Q8_0 conversions of both PLaMo 2 models were broken
in a strange, deterministic way: on any prompt containing a bare newline (token 10, which PLaMo 2 maps to the
reserved byte token 0x0A), the model emitted <|plamo:reserved:0x1F|> (token 31)
indefinitely. The debugging chain, each step verified:
- Tokenizer round-trip: OK. Encoding/decoding the offending prompts through the GGUF tokenizer is faithful, so this is not a tokenizer conversion bug.
- HF transformers reference: correct. The fp32 reference implementation generates correctly on the
same prompts (run with pure-PyTorch shims for
causal_conv1d/mamba_ssm, since pfnet’s CPU fallback path is itself broken).
- bf16 GGUF: correct. So the GGUF conversion is fine; the failure is introduced by quantization.
- Per-tensor isolation: ssm_out alone is responsible. Keeping ssm_in, ssm_x, ssm_dt, embeddings, or the output head in high precision while quantizing the rest still reproduces the bug; keeping only ssm_out at bf16 fixes it completely.
- It affects both model sizes. 1B and 8B fail identically, suggesting a property of the PLaMo 2 Mamba2 block’s output projection (e.g. outlier sensitivity) rather than a single bad checkpoint.
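The outlier-sensitivity hypothesis can be illustrated, though not confirmed, with a pure-Python model of Q8_0 block quantization. ggml's Q8_0 stores each block of 32 weights as int8 codes plus one scale, computed as the block's largest magnitude divided by 127. A single large weight therefore coarsens the grid for the other 31. The sketch ignores the fp16 rounding of the scale, and the block values are made up. It says nothing about PLaMo 2's actual ssm_out weights.

```python
import math

QK8_0 = 32  # values per Q8_0 block in ggml

def quantize_q8_0_block(xs):
    """One block -> (scale, int8 codes); the fp16 rounding of the scale is ignored."""
    assert len(xs) == QK8_0
    amax = max(abs(x) for x in xs)
    d = amax / 127.0                       # largest magnitude maps to +/-127
    inv = 1.0 / d if d else 0.0
    # round half away from zero, like C's roundf
    qs = [int(math.copysign(math.floor(abs(x) * inv + 0.5), x)) for x in xs]
    return d, qs

def mean_abs_error(xs, skip=()):
    """Mean |x - dequant(quant(x))| over the block, optionally skipping indices."""
    d, qs = quantize_q8_0_block(xs)
    errs = [abs(x - d * q) for i, (x, q) in enumerate(zip(xs, qs)) if i not in skip]
    return sum(errs) / len(errs)

small = [0.01 * math.sin(i + 1) for i in range(QK8_0)]   # made-up "typical" weights
with_outlier = [1.0] + small[1:]                          # same block, one big weight

clean_err = mean_abs_error(small)
outlier_err = mean_abs_error(with_outlier, skip={0})      # error on the other 31 only
print(f"without outlier: {clean_err:.2e}   with one outlier: {outlier_err:.2e}")
```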
Mitigation used for every model in this study:
llama-quantize --tensor-type ssm_out=bf16 plamo-2-8b-BF16.gguf plamo-2-8b-Q8fixed.gguf Q8_0
This likely affects all community PLaMo 2 GGUFs quantized with default settings. If a PLaMo 2
GGUF gets stuck emitting <|plamo:reserved:0x1F|> after a newline, re-quantize with
--tensor-type ssm_out=bf16. Reported upstream as
ggml-org/llama.cpp#24501.
(Full command log in commands-log.md, section 3, in the experiment repo.)
5. Results
Median decode throughput (tok/s) over 12 prompts × 3 repeats, speedup vs. baseline, and [token acceptance rate]:
| config | code | ja_chat | translation_en_ja |
|---|---|---|---|
| fx_baseline | 19.7 (1.00×) | 21.6 (1.00×) | 15.7 (1.00×) |
| fx_ngram | 18.1 (0.92×) [19%] | 17.6 (0.81×) [63%] | 19.0 (1.21×) [23%] |
| fx_spec_d2 | 16.5 (0.84×) [90%] | 13.8 (0.64×) [80%] | 15.5 (0.99×) [87%] |
| fx_spec_d4 | 12.0 (0.61×) [77%] | 10.2 (0.47×) [55%] | 18.5 (1.18×) [70%] |
| fx_spec_d8 | 11.6 (0.59×) [55%] | 7.9 (0.36×) [32%] | 11.1 (0.71×) [48%] |
| fx_spec_d16 | 11.7 (0.59×) [33%] | 6.2 (0.29×) [17%] | 9.3 (0.59×) [29%] |
The headline: despite the draft model being 6.3× faster than the target standalone,
draft-model speculation is not reliably profitable here. The clearest win is
n-gram (prompt-lookup) drafting on translation (1.21×). The only draft-model cell above baseline, spec_d4 on translation
(1.18×), is isolated: the neighboring depths on the same task lose (0.99× at d2, 0.71× at d8), and translation has the
noisiest baseline (section 5c). Everything else is a slowdown, sometimes a dramatic one.
5a. Speedup vs. draft length
Lines: draft-model speculation (draft-n-max 2–16). Dashed horizontals: n-gram
drafting (no draft length).
Two things stand out. First, the best draft length is the shortest one almost everywhere — deeper speculation only digs the hole deeper. Second, the one task where speculation approaches profitability (translation) is the task where the baseline is slowest (15.7 tok/s): the relative cost of a speculation round is smaller when target decode steps are more expensive, and translation outputs are repetitive enough to keep acceptance high at depth 4.
5b. Acceptance decays with depth
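Token acceptance rate vs. draft-n-max. Acceptance is per drafted token. The chance that the i-th drafted position
survives falls roughly geometrically with i, because each extra speculative position compounds the chance of an earlier
mismatch, so the average over all drafted positions is expected to fall as depth grows.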
At depth 2 the 1B draft is excellent — 80–90% of drafted tokens accepted. By depth 16 it is wasting most of its work (17–33% acceptance), yet each round still pays full price: serial draft decode of 16 tokens, a 17-token verification batch, and SSM checkpoint traffic.
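A minimal model of that compounding, assuming for illustration only that every drafted position is accepted with the same probability p, given that all earlier positions were. Under that assumption, the i-th position survives with probability p^i, and the reported rate is the average over the k drafted positions (drafts are also assumed to be full length). The value p = 0.9 is arbitrary. The model gives intuition for the shape of the curve; it is not a fit to the table.

```python
def expected_accepted(p, k):
    """Expected accepted drafts per round: position i survives with prob p**i."""
    return sum(p ** i for i in range(1, k + 1))

def acceptance_rate(p, k):
    """Accepted / drafted, the per-drafted-token metric in the results table."""
    return expected_accepted(p, k) / k

for k in (2, 4, 8, 16):
    print(f"draft-n-max {k:2d}: acceptance {acceptance_rate(0.9, k):.0%}")
```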
5c. Per-request distributions
Box + individual points: decode tok/s for each of the 36 requests per config and task (12 prompts × 3 repeats). Note baseline translation’s wide spread: prompt-dependent output length and content make that task’s baseline noisier than the others.
5d. Acceptance vs. speedup: the profitability map
Each point is one config × task aggregate. The dashed line is breakeven (1.0×); points are grouped and labeled by config.
This is the chart that kills the simple mental model. spec_d2 on code achieves 90% acceptance
and still runs 16% slower than baseline. On bandwidth-bound Apple Silicon, the serial draft-model decode, the
batched verification pass, and the SSM-state checkpoint save/restore together cost more per round than the accepted
tokens save — even when nearly every drafted token is accepted. Deeper drafts amortize the fixed costs worse, not
better, because per-token acceptance decays with depth. N-gram drafting wins on translation precisely because its
drafts are nearly free: no second model, no extra decode stream — just prompt lookup.
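The profitability map can be inverted into a back-of-envelope estimate of what one round costs. The estimate rests on two stated assumptions: every round drafts the full draft-n-max tokens, and each round emits 1 + a·k tokens (the accepted drafts plus the one token the target produces during verification). Under those assumptions, a measured speedup s implies a round cost of (1 + a·k)/s baseline decode steps. The helper name is hypothetical. The inputs are the table's medians, so this mixes per-request and aggregate quantities. The draft share uses the standalone 134 vs. 21.4 tok/s figures from section 1.

```python
def implied_round_cost(speedup, acceptance, k):
    """Per-round cost implied by a measured speedup, in baseline decode steps.

    Assumes each round emits 1 + acceptance * k tokens: the accepted drafts
    plus the one token the target produces while verifying.
    """
    tokens_per_round = 1 + acceptance * k
    return tokens_per_round / speedup

# (speedup, acceptance, draft-n-max) copied from the results table
cells = {
    "code d2": (0.84, 0.90, 2),
    "code d16": (0.59, 0.33, 16),
    "ja_chat d16": (0.29, 0.17, 16),
    "translation d4": (1.18, 0.70, 4),
}
DRAFT_TO_TARGET = 21.4 / 134   # standalone step-time ratio, 1B draft vs 8B target

for name, (s, a, k) in cells.items():
    total = implied_round_cost(s, a, k)
    drafting = k * DRAFT_TO_TARGET   # serial draft decode at standalone speed
    print(f"{name}: {total:.2f} steps/round, of which ~{drafting:.2f} drafting")
```

For code at depth 2, this gives 2.8 / 0.84 ≈ 3.3 baseline steps per round. Drafting two tokens at standalone speed explains only about 0.3 of that. The remaining ≈3 steps go to verifying a 3-token batch plus checkpoint save/restore, whereas the premise of speculative decoding is that the verification batch costs about one target step.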
6. Honest conclusions
- Correctness: solved. llama.cpp’s recurrent-state checkpointing (PR #19493) makes speculative decoding on Mamba-hybrid models exact — byte-identical greedy output, verified per prompt. The same snapshot-and-restore primitive now exists in vLLM and SGLang too (we tested only llama.cpp; see platform landscape).
- Performance: a measured negative result. On an M4 Pro with Metal, draft-model speculation for plamo-2-8b/plamo-2-1b at Q8 beat baseline in only one of twelve task × depth cells (spec_d4 on translation, 1.18×, flanked by 0.99× at depth 2 and 0.71× at depth 8), despite up to 90% acceptance and a 6.3× standalone speed gap. If you are deploying PLaMo 2 on Apple Silicon today, run it without a draft model.
- The one win is the cheapest one. N-gram/prompt-lookup drafting gives 1.21× on translation — a task with high surface overlap between prompt and output — at zero extra memory. It is the only speculation worth enabling here, and only for that workload shape.
- Mechanism, not mystery. Per-round costs (serial draft decode, batched verification, SSM checkpoint save/restore) exceed per-round savings on bandwidth-bound hardware where batch-of-5 verification is not much cheaper than 5 sequential steps. The result could flip on compute-rich, batch-friendly hardware (server GPUs), with a smaller relative checkpoint cost, or with a draft tuned for longer agreement runs — all worth measuring before assuming.
- Check your PLaMo 2 GGUFs. Default Q8_0 quantization breaks the model via the
ssm_out tensor; keep it at bf16.
Platform landscape & scope
We ran this experiment on llama.cpp on a local MacBook (Apple M4 Pro, Metal), and did not test the other engines. That is a deliberate scope choice, not a claim that llama.cpp is unique. As of our measurements (mid-2026), the SSM-rollback problem had been solved in at least three stacks, which independently converged on the same snapshot-and-restore primitive:
- llama.cpp (the engine we used): merged as PR #19493 (April 2026). An earlier alternative, #20075, was closed unmerged.
- vLLM: into early 2026 it rejected the combination outright
(#30114, closed April 2026); support was enabled in
PR #33726 (merged Feb 2026). Spec decode now works on
Mamba/GDN hybrids conditionally, with
--mamba-cache-mode align plus self-drafting heads (EAGLE3, MTP), while model-free n-gram drafting can still corrupt recurrent state on some hybrids.
- SGLang: enabled in PR #13434 (merged Dec 2025); ships the same primitive on NVIDIA hardware: one isolated recurrent-state slot per draft token, plus a convolution-window snapshot/restore. Not yet lossless on every backend.
The open question across all three has shifted from whether rollback is correct to which proposer types preserve state exactly — the very property this report’s SHA-256 byte-identical check measures. Because our negative throughput result is about the economics of running a separate draft model on bandwidth-bound hardware — not the correctness machinery — we expect it to be robust on similar hardware; but we have not measured vLLM or SGLang, and server-class GPUs with cheaper batched verification could change the verdict.
7. Reproducibility: the commands used
Setup — llama.cpp b9596 release binaries (macOS arm64, Metal) and the matching conversion script:
# inference binaries (llama-server, llama-completion, llama-quantize, ...)
curl -sL https://github.com/ggml-org/llama.cpp/releases/download/b9596/llama-b9596-bin-macos-arm64.tar.gz | tar -xz
# conversion tooling at the matching tag
git clone --depth 1 --branch b9596 https://github.com/ggml-org/llama.cpp llama.cpp-src
pip install ./llama.cpp-src/gguf-py sentencepiece
Models — download (plamo-2-8b is gated: accept the PLaMo Community License on its HF page first),
convert to an unquantized bf16 GGUF, then quantize with the ssm_out exemption from §4. Quantizing
straight to Q8_0 reproduces the bug:
hf download pfnet/plamo-2-1b # Apache-2.0
hf download pfnet/plamo-2-8b # PLaMo Community License, gated
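python llama.cpp-src/convert_hf_to_gguf.py <plamo-2-8b-snapshot-dir> \
  --outfile plamo-2-8b-BF16.gguf --outtype bf16
python llama.cpp-src/convert_hf_to_gguf.py <plamo-2-1b-snapshot-dir> \
  --outfile plamo-2-1b-BF16.gguf --outtype bf16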
# the fix discovered in this study: keep ssm_out unquantized
llama-quantize --tensor-type ssm_out=bf16 plamo-2-8b-BF16.gguf plamo-2-8b-Q8fixed.gguf Q8_0
llama-quantize --tensor-type ssm_out=bf16 plamo-2-1b-BF16.gguf plamo-2-1b-Q8fixed.gguf Q8_0
Benchmark configurations — one llama-server instance per config; greedy requests
(temperature 0, top_k 1, n_predict 256, cache_prompt false) against /completion. Note
--spec-type: in b9596, loading a draft model with -md alone does not enable
speculation — without it the server logs “no implementations specified for speculative
decoding” and runs plain decoding (this silently invalidated our first sweep):
# baseline (no speculation)
llama-server -m plamo-2-8b-Q8fixed.gguf -ngl 99 -c 4096 --no-webui
# draft-model speculation (sweep --spec-draft-n-max over 2, 4, 8, 16)
llama-server -m plamo-2-8b-Q8fixed.gguf -ngl 99 -c 4096 --no-webui \
--spec-type draft-simple -md plamo-2-1b-Q8fixed.gguf -ngld 99 \
--spec-draft-n-max 8 --spec-draft-n-min 1 --spec-draft-p-min 0.0
# n-gram / prompt-lookup speculation (no draft model)
llama-server -m plamo-2-8b-Q8fixed.gguf -ngl 99 -c 4096 --no-webui --spec-type ngram-simple
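One way to rule out the --spec-type pitfall mechanically is to generate the sweep from a single table instead of typing each command by hand. This is a sketch under stated assumptions. The flags are copied from the three commands above. The helper name server_argv and the config-name parsing are hypothetical (not the repo's bench.py). Starting the process, choosing ports, and waiting for the model to load are left out.

```python
TARGET = "plamo-2-8b-Q8fixed.gguf"
DRAFT = "plamo-2-1b-Q8fixed.gguf"
COMMON = ["llama-server", "-m", TARGET, "-ngl", "99", "-c", "4096", "--no-webui"]

def server_argv(config):
    """Build the llama-server argv for one benchmark config name."""
    if config == "fx_baseline":
        return list(COMMON)
    if config == "fx_ngram":
        return COMMON + ["--spec-type", "ngram-simple"]
    if config.startswith("fx_spec_d"):
        n_max = int(config.removeprefix("fx_spec_d"))
        # --spec-type is required: -md alone loads the draft but does not speculate
        return COMMON + [
            "--spec-type", "draft-simple", "-md", DRAFT, "-ngld", "99",
            "--spec-draft-n-max", str(n_max), "--spec-draft-n-min", "1",
            "--spec-draft-p-min", "0.0",
        ]
    raise ValueError(f"unknown config: {config}")

CONFIGS = ["fx_baseline", "fx_ngram"] + [f"fx_spec_d{n}" for n in (2, 4, 8, 16)]
for name in CONFIGS:
    print(name, " ".join(server_argv(name)))
```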
Per-request decode speed and draft statistics come from the response’s timings object
(predicted_per_second, draft_n, draft_n_accepted); losslessness is checked by
sha256-comparing each config’s greedy output against baseline per prompt. Harness:
bench/bench.py + bench/prompts.json + bench/analyze.py in the experiment repo.
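Below is a minimal client for the measurement loop. It assumes llama-server's default address, http://127.0.0.1:8080, and uses only the request fields and timings keys named above. The function names are hypothetical, and this is not the repo's bench.py. In particular, pooling drafted and accepted counts across requests (rather than averaging per-request rates) is an assumption about how acceptance is aggregated.

```python
import json
import statistics
import urllib.request

def completion_payload(prompt):
    """Greedy request body matching the settings described above."""
    return {"prompt": prompt, "temperature": 0, "top_k": 1,
            "n_predict": 256, "cache_prompt": False}

def post_completion(prompt, base_url="http://127.0.0.1:8080"):
    """POST one request to llama-server's /completion endpoint; return parsed JSON."""
    req = urllib.request.Request(
        base_url + "/completion",
        data=json.dumps(completion_payload(prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

def summarize(responses):
    """Median decode tok/s and pooled draft acceptance over a list of responses."""
    timings = [r["timings"] for r in responses]
    tok_s = statistics.median(t["predicted_per_second"] for t in timings)
    drafted = sum(t.get("draft_n", 0) for t in timings)        # absent without speculation
    accepted = sum(t.get("draft_n_accepted", 0) for t in timings)
    acceptance = accepted / drafted if drafted else None
    return tok_s, acceptance
```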
Bug repro in one minute (§4): any prompt containing a bare newline works. The plain-Q8_0 model
emits <|plamo:reserved:0x1F|> (token 31) until the token limit, while the fixed one writes code:
llama-quantize plamo-2-1b-BF16.gguf plamo-2-1b-Q8_0.gguf Q8_0   # plain Q8_0, reproduces the bug
printf '# Fibonacci with memoization\ndef fibonacci(n, memo=None):\n' > p.txt
llama-completion -m plamo-2-1b-Q8_0.gguf -ngl 99 -n 32 --temp 0 -f p.txt # empty output
llama-completion -m plamo-2-1b-Q8fixed.gguf -ngl 99 -n 32 --temp 0 -f p.txt # correct code
Data: 6 configs × 3 tasks × 12 prompts × 3 repeats, llama.cpp server build b9596 release binaries,
greedy sampling, n_predict 256. Raw per-request JSONL and server logs in the experiment repo
(results/fx_*.jsonl). All numbers in the charts are embedded in this file; it has no network dependency
except the Plotly.js CDN. References: PLaMo 2 Technical Report (arXiv:2509.04897), ggml-org/llama.cpp PR #19493
(supersedes the closed #20075), vllm-project/vllm issue #30114 (closed April 2026).