Support for realistic multi-step rollouts via async vLLM API #3284
Comments
Great job on all the work you've done, and thanks for sharing it with the TRL community! This likely involves significant changes to TRL, but your motivations seem solid and well thought out—it feels like the right time to explore this direction. Would it be possible to break your work into several smaller PRs? That would make the review process much smoother. For example, you could start with a PR focused on leveraging the vLLM server, followed by another that integrates the tools/agents. (Of course, feel free to divide it differently if you think there's a better approach.)
Will do! I believe I've found a clean abstraction that minimizes the impact on existing code. Specifically, I'm exploring repurposing `VLLMClient`. The only other change would be to pass full data dictionaries (rather than just prompts) into `generate()`. I'll keep iterating on this until I find something that's both elegant and fits my specific use case. Once it's settled, I'll split it into smaller, reviewable PRs. I believe this could meaningfully lower the barrier to entry in this specific domain of RL training.

Here's a minimal example showing how my use case looks now. With this, the "normal" `GRPOTrainer` setup stays essentially unchanged:

```python
import multiprocessing as mp
import os
from contextlib import redirect_stderr, redirect_stdout
from typing import Any

from datasets import load_dataset

from aider.coders import Coder
from aider.io import InputOutput
from aider.models import Model
from trl import GRPOConfig, GRPOTrainer
from trl.extras.vllm_client import VLLMClient

# clone_repo_at_commit, get_head_commit_diff and clean_repo_dir are project-specific git helpers


class AiderClient(VLLMClient):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point Aider's OpenAI-compatible backend at the TRL-managed vLLM server
        os.environ["OPENAI_API_BASE"] = f"http://{self.host}:{self.server_port}/v1/completions"

    def process_one(self, data: dict[str, Any]) -> tuple[str, list]:
        orig = os.getcwd()
        try:
            temp = clone_repo_at_commit(data["repo_url"], data["base_commit"])
            os.chdir(temp)
            # Silence Aider's console output while it works on the task
            with open(os.devnull, "w") as devnull, redirect_stdout(devnull), redirect_stderr(devnull):
                coder = Coder.create(
                    main_model=Model("openai/our-model"),
                    io=InputOutput(yes=True),
                    suggest_shell_commands=False,
                )
                coder.run(data["problem_statement"])
            messages = coder.format_chat_chunks().all_messages()
            diff = get_head_commit_diff(temp)
        finally:
            clean_repo_dir(temp)
            os.chdir(orig)
        return diff, messages

    def generate(self, data: list[dict[str, Any]], timeout: int = 300, **kwargs) -> list[dict[str, Any]]:
        with mp.Pool(min(len(data), mp.cpu_count())) as pool:
            results = pool.map_async(self.process_one, data).get(timeout=timeout)
        for item, (diff, messages) in zip(data, results):
            item["generated_diff"] = diff
            item["messages"] = messages
        return data


trainer = GRPOTrainer(
    args=GRPOConfig(use_vllm=True),
    client=AiderClient(host="0.0.0.0", server_port=8000),  # `client` is the proposed extension point
    train_dataset=load_dataset("SWE-Gym/SWE-Gym", split="train"),
)
trainer.train()
```
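The example above leans on a few project-specific git helpers that aren't shown. A rough sketch of what they might look like, with names and behavior inferred from the call sites (the `git show` step assumes the agent commits its edits, which Aider does by default):

```python
import shutil
import subprocess
import tempfile


def clone_repo_at_commit(repo_url: str, commit: str) -> str:
    """Clone the repository into a temporary directory and check out the given commit."""
    temp_dir = tempfile.mkdtemp(prefix="rollout_repo_")
    subprocess.run(["git", "clone", repo_url, temp_dir], check=True, capture_output=True)
    subprocess.run(["git", "checkout", commit], cwd=temp_dir, check=True, capture_output=True)
    return temp_dir


def get_head_commit_diff(repo_dir: str) -> str:
    """Return the patch introduced by the HEAD commit (i.e. the agent's last commit)."""
    result = subprocess.run(
        ["git", "show", "--format=", "HEAD"],
        cwd=repo_dir, check=True, capture_output=True, text=True,
    )
    return result.stdout


def clean_repo_dir(repo_dir: str) -> None:
    """Remove the temporary checkout."""
    shutil.rmtree(repo_dir, ignore_errors=True)
```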
OK, modifying the client-server setup seems acceptable to me, especially if it allows easier customization.
Hey @BjarniHaukur – I’m also looking at migrating GRPO roll-outs to an online vLLM setup for better performance and agent-style usability. |
Hey @kwanUm, still working on it. I've been super close to finishing this for a while now. The main problem resides in the online weight-syncing behavior for the async vLLM server.

I've put quite a lot of thought into how it would be best to integrate custom clients, and I'm relatively convinced of my approach. It decouples all the generation logic from the `GRPOTrainer` and offloads it to `client.generate()`, which receives the full data dictionaries and returns a list of `GenerationResult`s (overly simplified example):

```python
from abc import ABC, abstractmethod
from typing import TypedDict


class GenerationResult(TypedDict, total=False):
    """GRPO payload with required prompt/completion; extras allowed."""

    # Shared inputs across N rollouts in GRPO (across many GenerationResults)
    prompt: list[dict[str, str]]  # {role: str, content: str}
    # This comes after that, N different rollouts of the same prompt
    completion: list[dict[str, str]]  # {role: str, content: str}
    # Extra keys and values are forwarded to the user-specified reward functions


class VLLMClient(ABC):
    @abstractmethod
    def generate(self, data: list[dict], **kwargs) -> list[GenerationResult]:
        pass


# Inside GRPOTrainer
...
output = client.generate(inputs)
...
rewards = reward_func(**output)
```

You can check out my working branch (though, fair warning, it's not stable / ready at all); it might help you. I'll post here again when I have something more concrete. Would love some help in integrating this type of behavior into TRL though! There's already some semblance of it in the existing code.
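Since any extra keys on a `GenerationResult` are forwarded to the user-specified reward functions, a reward function in this scheme could look roughly like the sketch below, following TRL's usual reward-function convention (prompts, completions, plus extra columns as keyword arguments). The `generated_diff` key and the scoring rule are made up for illustration:

```python
def diff_reward(prompts, completions, generated_diff=None, **kwargs) -> list[float]:
    """Toy reward: prefer rollouts whose agent run produced a non-empty diff."""
    return [1.0 if diff and diff.strip() else 0.0 for diff in generated_diff]
```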
Hey @qgallouedec, finally got it working and found an abstraction that I believe could fit in TRL (#3469). The new async server mirrors the weight-syncing logic of the existing `trl vllm-serve` script while exposing the full OpenAI-compatible API. Instead of extending `VLLMClient` as I sketched earlier, the rollout logic now lives in a user-provided rollout function, and the model is served with:

```bash
trl vllm-serve-async \
    --model Qwen/Qwen3-8B \
    --max_model_len 8192 \
    --enable-auto-tool-choice \
    --reasoning_parser deepseek_r1 \
    --tool-call-parser hermes
```

This allows any LLM-powered application with measurable reward metrics to be trained with GRPO. My CodeRepairRL project provides an example of a `rollout_func` using a terminal-based coding agent (Nano-Agent). @kwanUm, this might be of interest to you too!
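To make the idea concrete, here's a rough sketch of what such a `rollout_func` could look like. The exact signature in #3469 may differ; the endpoint, model name, dataset fields, and single round trip (a real agent would loop with tool calls) are all placeholders for illustration:

```python
import requests


def rollout_func(rows: list[dict], **kwargs) -> list[dict]:
    """Sketch: one chat-completion round trip per dataset row against the async vLLM server."""
    results = []
    for row in rows:
        messages = [{"role": "user", "content": row["problem_statement"]}]
        response = requests.post(
            "http://0.0.0.0:8000/v1/chat/completions",  # the server started above
            json={"model": "Qwen/Qwen3-8B", "messages": messages},
            timeout=300,
        )
        completion = response.json()["choices"][0]["message"]
        results.append(
            {
                "prompt": messages,
                "completion": [completion],
                # Extra keys (tool traces, diffs, ...) can be added here for the reward functions
            }
        )
    return results
```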
Feature request
I propose adding a new OpenAI-compatible vLLM API server for use with the `GRPOTrainer`. The implementation mirrors the weight syncing logic from `trl/scripts/vllm_serve.py`, but offloads most complexity to the existing `vllm.entrypoints.openai.api_server` infrastructure.

This enables training on significantly more complex rollouts than the standard synchronous `.generate()` endpoint can support. By supporting the OpenAI API interface, it also allows seamless integration with a wide range of existing agent frameworks and products.
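To illustrate what the OpenAI compatibility buys (this is not code from the PR; host, port, and model name are placeholders): any off-the-shelf OpenAI client, or an agent framework built on one, can target the training server simply by overriding the base URL.

```python
from openai import OpenAI

# Point a standard OpenAI client at the TRL-managed vLLM server
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Propose a fix for the failing test."}],
)
print(response.choices[0].message.content)
```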
This direction is a step toward reproducing pipelines like OpenHands LM 32B. I strongly suspect that Claude 3.7 Sonnet was trained in a similar fashion, iteratively reinforced using rollouts generated through its own Claude Code scaffolding.
Motivation
Currently, TRL only supports synchronous, batched `.generate()` calls for inference. This restricts the types of rollouts that can be created, especially in domains that benefit from multi-step approaches, tool use, or environment interaction.

I've been using TRL for my Master's thesis on reinforcement learning for language models in the program repair domain. In several GRPO experiments, I repeatedly encountered the same limitation: with `.generate()`, all context construction, planning, and feedback extraction must happen within a single call. For example, in tasks from SWE-Gym, the model needs to generate code edits for real repositories. To do this in one `.generate()` call, the user must manually construct the relevant repo context and later parse outputs like diffs to extract useful reward signals. This makes experimentation slow and always feels like "reinventing the wheel."

Rather than building ad-hoc scaffolding from scratch, I began exploring how to integrate existing coding agents like Aider directly into the training loop. These agents already support rich workflows such as repo mapping, diff parsing, and iterative interaction, and they already speak the OpenAI API. Enabling TRL to train models through this interface would allow us to run them in situ, inside the same environment they're meant to be deployed in.
This proposal aims to bridge that gap and enable more realistic, multi-step training workflows as a first-class feature in TRL.
Your contribution
I have already developed an initial working implementation in this PR: #3285, which introduces `vllm_serve_openai_compatible.py`. I intend to wrap up remaining loose ends and properly test this approach, both for functional correctness and throughput benchmarking.

The draft PR also includes a few project-specific utilities (WIP) to illustrate how this can be used in practice. For example, it shows how to parallelize existing Aider instances that interact with this server to generate training data.

One open issue is how to reliably access full conversation histories for each rollout. Since API calls happen internally within the agent, we cannot assume access to `.get_conversation_history()` or similar. A possible approach is to record all requests and responses server-side and map them back to the original prompt to reconstruct complete rollouts to train on.

I'd be happy to align the implementation with TRL's design goals and iterate toward something mergeable.
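The server-side recording idea could be prototyped with a small in-process recorder that the chat-completions handler (or an ASGI middleware in front of it) feeds on every call. The sketch below only shows the bookkeeping, not the actual hook into the vLLM OpenAI server, and keying rollouts by the first user message is just one possible heuristic:

```python
from collections import defaultdict
from typing import Any


class RolloutRecorder:
    """Collect every (request, response) pair seen by the chat-completions endpoint,
    keyed by the first user message so each agent trajectory can be mapped back to
    its original training prompt."""

    def __init__(self) -> None:
        self._log: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def record(self, request_body: dict[str, Any], response_body: dict[str, Any]) -> None:
        key = next(
            (m["content"] for m in request_body.get("messages", []) if m.get("role") == "user"),
            "<no-user-message>",
        )
        self._log[key].append({"request": request_body, "response": response_body})

    def reconstruct(self, original_prompt: str) -> list[dict[str, Any]]:
        """Return, in call order, every API exchange the agent made for this prompt."""
        return list(self._log.get(original_prompt, []))
```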