Fast Reasoning on GPT-OSS with Speculative Decoding and Arctic Inference


Reasoning models like OpenAI’s new gpt-oss-20b and gpt-oss-120b “think out loud.” They plan, branch and revise their generated outputs. While these extra reasoning steps often lead to higher-quality, more accurate results, they also introduce trade-offs: more tokens generated, slower response times and noticeable degradation in app performance.
For agentic workloads, speed is non-negotiable. Many have strict latency budgets that must be met to run reliably in production, yet the long reasoning chains these models produce can add enough delay to render agentic applications impractical and limit their potential.
At Snowflake AI Research, we developed Arctic Inference, a state-of-the-art open source vLLM plugin that optimizes LLM decoding. In this blog, we’ll walk through how we applied our lightweight speculative decoding method, Arctic Speculator, to the gpt-oss models to improve generation speed by 1.6-1.8x.
Arctic Speculator for gpt-oss
Reasoning models often produce long chains of thought, generating far more tokens than other LLMs. Each output token requires its own forward pass through every layer of the model, and over long reasoning chains these repeated passes compound quickly, causing high latency.
Arctic Speculator addresses this bottleneck with a lightweight draft model: a smaller, faster model that predicts several likely next tokens in a single step. Arctic Speculator integrates directly into the forward pass, using the base model’s final hidden state to make its predictions. The larger base model then verifies those predictions in parallel, so when they are accepted, multiple tokens are produced in a single verification pass. You can learn more about how Arctic Speculator works here.
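To make the draft-and-verify flow concrete, here is a minimal, framework-agnostic sketch of a single iteration with greedy acceptance. The function names and signatures are illustrative placeholders, not Arctic Inference’s actual API:

```python
# Illustrative sketch of one speculative decoding iteration (not the actual
# Arctic Inference implementation). Greedy acceptance is shown for simplicity.
from typing import Callable, List

def speculative_step(
    tokens: List[int],
    hidden_state,                  # base model's final hidden state for tokens[-1]
    draft_next_tokens: Callable,   # (hidden_state, last_token) -> k drafted token ids
    base_forward: Callable,        # (token ids) -> base model's argmax prediction per position
    k: int = 3,
) -> List[int]:
    # 1. Draft: cheaply propose the next k tokens from the hidden state.
    draft = draft_next_tokens(hidden_state, tokens[-1])[:k]

    # 2. Verify: one forward pass of the base model over tokens + draft yields
    #    the base model's own next-token prediction at every drafted position.
    verified = base_forward(tokens + draft)

    # 3. Accept the longest draft prefix that matches the base model, then
    #    append the base model's next token after the accepted prefix.
    accepted: List[int] = []
    for i, d in enumerate(draft):
        if d == verified[len(tokens) - 1 + i]:
            accepted.append(d)
        else:
            break
    accepted.append(verified[len(tokens) - 1 + len(accepted)])
    return tokens + accepted
```

Because every accepted token matches what the base model would have produced anyway, a greedy scheme like this changes only the speed of generation, not its output.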
We trained Arctic Speculator draft models for gpt-oss-20b and gpt-oss-120b, which can be run easily with vLLM via the Arctic Inference plugin.
Architecture and training
For each model, we trained an LSTM-based draft speculator with 1.76B parameters that ingests the base model’s final hidden state and last token ID and predicts the next three tokens.
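As a rough illustration of that design (not the released architecture; the layer sizes, greedy feedback loop and other details below are assumptions), such a draft head can be sketched in PyTorch as an LSTM cell that consumes the base model’s final hidden state plus an embedding of the last token and unrolls for three steps:

```python
# Rough PyTorch sketch of an LSTM-style draft head. Dimensions and layout are
# illustrative; the released Arctic Speculator checkpoints define the real ones.
import torch
import torch.nn as nn

class LSTMDraftHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_draft_tokens: int = 3):
        super().__init__()
        self.num_draft_tokens = num_draft_tokens
        self.embed = nn.Embedding(vocab_size, hidden_size)   # last-token embedding
        self.cell = nn.LSTMCell(hidden_size, hidden_size)    # recurrent draft step
        self.lm_head = nn.Linear(hidden_size, vocab_size)    # per-step token logits

    def forward(self, base_hidden: torch.Tensor, last_token: torch.Tensor) -> torch.Tensor:
        # base_hidden: [B, H] final hidden state from the base model
        # last_token:  [B] id of the most recently generated token
        h = base_hidden
        c = torch.zeros_like(base_hidden)
        x = self.embed(last_token)
        logits = []
        for _ in range(self.num_draft_tokens):
            h, c = self.cell(x, (h, c))
            step_logits = self.lm_head(h)
            logits.append(step_logits)
            # Feed the greedy prediction back in for the next draft step.
            x = self.embed(step_logits.argmax(dim=-1))
        return torch.stack(logits, dim=1)  # [B, num_draft_tokens, vocab_size]
```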
We trained each speculator for 3,000 steps (~1 epoch) on the UltraChat and Magicoder data sets, using offline distillation to match the original base model’s distribution. Training ran on 16 H200 GPUs using the open source Arctic Training framework, and took three hours for gpt-oss-20b and 21 hours for gpt-oss-120b. Our training code is open sourced and fully reproducible via Arctic Training’s declarative YAML-based configs.
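In spirit, the offline-distillation objective resembles the following sketch. This is a generic illustration rather than Arctic Training’s actual training loop, and the tensor names and shapes are assumptions:

```python
# Generic sketch of an offline-distillation step for a draft speculator.
# Assumption (not Arctic Training's actual code): hidden states, last-token ids,
# and the base model's next-token logits were precomputed offline from the
# frozen base model over the training prompts.
import torch
import torch.nn.functional as F

def distillation_step(speculator: torch.nn.Module,
                      optimizer: torch.optim.Optimizer,
                      batch: dict) -> float:
    # batch["hidden"]:         [B, H]    final hidden states from the frozen base model
    # batch["last_token"]:     [B]       ids of the most recent tokens
    # batch["teacher_logits"]: [B, 3, V] base-model logits for the next three positions
    draft_logits = speculator(batch["hidden"], batch["last_token"])  # [B, 3, V]
    teacher_probs = F.softmax(batch["teacher_logits"], dim=-1)
    # Soft-label cross-entropy pulls the draft head toward the teacher distribution.
    loss = F.cross_entropy(draft_logits.flatten(0, 1), teacher_probs.flatten(0, 1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```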
These trained draft models can be found on Hugging Face; the gpt-oss-120b speculator, for example, is published as Snowflake/Arctic-LSTM-Speculator-gpt-oss-120B, the same checkpoint referenced in the serving config below.
Arctic Speculator results
We evaluated our Arctic Speculator models on real-world requests from the ShareGPT data set and coding problems from HumanEval, measuring output tokens per second:
| Model: gpt-oss-120b | ShareGPT (tokens/s) | HumanEval (tokens/s) |
| --- | --- | --- |
| No speculation (TP=4) | 220.2 | 220.7 |
| Arctic Speculator (TP=4) | 377.3 (1.7x faster) | 400.0 (1.8x faster) |

| Model: gpt-oss-20b | ShareGPT (tokens/s) | HumanEval (tokens/s) |
| --- | --- | --- |
| No speculation (TP=4) | 298.5 | 301.2 |
| Arctic Speculator (TP=4) | 476.2 (1.6x faster) | 490.2 (1.6x faster) |
The acceptance rate averages 44%-50% across our benchmark settings. With a speculation length of 3, each decoding iteration therefore produces 2.3-2.5 tokens on average instead of one. This results in an end-to-end TPOT (time per output token) speedup of 1.6-1.8x.
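One way to sanity-check these numbers: if a fraction α of the three drafted tokens is accepted in an iteration, the verification pass emits roughly 1 + 3·α tokens, i.e., the accepted draft tokens plus the token the base model produces itself:

```python
# Back-of-the-envelope check of tokens emitted per verification pass,
# using the 44%-50% acceptance rates and speculation length of 3 reported above.
k = 3  # speculation length (num_speculative_tokens)
for acceptance_rate in (0.44, 0.50):
    tokens_per_pass = 1 + k * acceptance_rate  # accepted drafts + base model's own token
    print(f"acceptance {acceptance_rate:.0%}: {tokens_per_pass:.2f} tokens per pass")
```

Plugging in the observed acceptance rates gives 2.32 and 2.50 tokens per pass, bracketing the range reported above.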
Try it yourself with Arctic Inference
Everything discussed in this blog is open sourced, and you can try it using Arctic Inference.
Arctic Inference is a drop-in vLLM plugin: you keep the same CLI and APIs and turn on Arctic Speculator with just a command-line change. First, install the gpt-oss build of vLLM:
uv pip install --pre vllm==0.10.1+gptoss \
--extra-index-url https://wheels.vllm.ai/gpt-oss/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
--index-strategy unsafe-best-match
Then install Arctic Inference using branch release-gpt-oss.
git clone --branch release-gpt-oss https://github.com/snowflakedb/ArcticInference.git
pip install ./ArcticInference
Run vLLM.
vllm serve openai/gpt-oss-120b \
--tensor-parallel-size 4 \
--speculative-config '{
  "method": "arctic",
  "model": "Snowflake/Arctic-LSTM-Speculator-gpt-oss-120B",
  "num_speculative_tokens": 3,
  "disable_by_batch_size": 64
}'
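Once the server is up, requests go through vLLM’s standard OpenAI-compatible API, so existing clients work unchanged. For example, with the openai Python client pointed at the default port (8000):

```python
# Minimal client example against the vLLM server started above.
# Assumes the default port (8000) and the `openai` Python package.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="openai/gpt-oss-120b",
    messages=[{"role": "user", "content": "Briefly explain speculative decoding."}],
    max_tokens=256,
)
print(response.choices[0].message.content)
```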
Optimizing open models for real-world AI
At Snowflake AI Research, we build in the open because AI is evolving fast, and so is our research. We share our progress openly and iteratively, tackling one enterprise AI challenge at a time. Our latest step is releasing Arctic Speculator models through Arctic Inference, helping developers cut latency and make agentic workloads practical on their own infrastructure.
Appendix (Benchmark details)
This section contains full instructions and configuration parameters for reproducing our benchmark results with Arctic Inference.
vLLM serve:
# Same as above
Benchmark:
model_path="openai/gpt-oss-120b"
# ShareGPT
python vllm/benchmarks/benchmark_serving.py \
    --backend vllm \
    --base-url http://127.0.0.1:8000 \
    --dataset-path=ShareGPT_2023.05.04v0_Wasteland_Edition.json \
    --dataset-name=sharegpt \
    --model $model_path \
    --num-prompts 100 \
    --max-concurrency 1
# HumanEval
python vllm/benchmarks/benchmark_serving.py \
    --backend vllm \
    --base-url http://127.0.0.1:8000 \
    --dataset-path=HumanEval.jsonl \
    --dataset-name=humaneval \
    --model $model_path \
    --num-prompts 100 \
    --max-concurrency 1
# Custom dataset class used for the HumanEval benchmark runs above. It relies on
# BenchmarkDataset, SampleRequest and the LoRA/oversampling helpers provided by
# vLLM's benchmark utilities (the same module that defines the other datasets).
import json
import random
from typing import Optional

from transformers import PreTrainedTokenizerBase


class HumanEvalDataset(BenchmarkDataset):
    """
    Implements the HumanEval dataset. Loads data from a JSONL file and
    generates one sample request per problem prompt.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.load_data()

    def load_data(self) -> None:
        if self.dataset_path is None:
            raise ValueError("dataset_path must be provided for loading data.")

        # HumanEval.jsonl stores one JSON object per line.
        with open(self.dataset_path, encoding="utf-8") as f:
            json_list = list(f)

        self.data = []
        for json_str in json_list:
            item = json.loads(json_str)
            self.data.append(item)

        # Shuffle deterministically so runs are reproducible.
        random.seed(self.random_seed)
        random.shuffle(self.data)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        lora_path: Optional[str] = None,
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        **kwargs,
    ) -> list:
        samples: list = []
        for entry in self.data:
            if len(samples) >= num_requests:
                break
            prompt = entry["prompt"]

            lora_request, tokenizer = self.get_random_lora_request(
                tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)

            prompt_ids = tokenizer(prompt).input_ids
            prompt_len = len(prompt_ids)

            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, None)

            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    lora_request=lora_request,
                ))

        # Repeat requests if the dataset holds fewer entries than num_requests.
        self.maybe_oversample_requests(samples, num_requests)
        return samples