Commit e8a18c1

Shawn-Tao and Kuangdd01 authored
[infer] Modify vllm_infer.py to batch preprocess to avoid too much files opened error (#8051)
Co-authored-by: Kingsley <[email protected]>
1 parent 2b23c0a commit e8a18c1

1 file changed (+84 −53 lines)


scripts/vllm_infer.py

Lines changed: 84 additions & 53 deletions
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import json
 from typing import Optional

 import fire
 from transformers import Seq2SeqTrainingArguments
+from tqdm import tqdm

 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
@@ -53,6 +55,7 @@ def vllm_infer(
     image_min_pixels: int = 32 * 32,
     video_fps: float = 2.0,
     video_maxlen: int = 128,
+    batch_size: int = 1024,
 ):
     r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
@@ -85,42 +88,28 @@
     tokenizer = tokenizer_module["tokenizer"]
     template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
     template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)

-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"],
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
+        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
+        "pipeline_parallel_size": pipeline_parallel_size,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
+
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]

     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
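Aside (not part of the diff): the hunk above now builds the vLLM engine once, via llm = LLM(**engine_args), before the dataset is loaded, so the loop added in the next hunk can reuse the same engine for every batch. For reference, a minimal, self-contained sketch of the request shape this script passes to llm.generate — pre-tokenized prompt IDs, with an optional "multi_modal_data" entry for media; the model name here is an assumption chosen only to keep the example small:

from vllm import LLM, SamplingParams

# Build the engine once (assumed small text-only model, for illustration only).
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
sampling_params = SamplingParams(max_tokens=16)

# vllm_infer.py sends pre-tokenized requests; multimodal samples would also
# carry a "multi_modal_data" key next to "prompt_token_ids".
token_ids = llm.get_tokenizer().encode("Hello, world!")
results = llm.generate([{"prompt_token_ids": token_ids}], sampling_params)
print(results[0].outputs[0].text)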
@@ -137,30 +126,72 @@
     else:
         lora_request = None

-    engine_args = {
-        "model": model_args.model_name_or_path,
-        "trust_remote_code": True,
-        "dtype": model_args.infer_dtype,
-        "max_model_len": cutoff_len + max_new_tokens,
-        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
-        "pipeline_parallel_size": pipeline_parallel_size,
-        "disable_log_stats": True,
-        "enable_lora": model_args.adapter_name_or_path is not None,
-    }
-    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+    # Store all results in these lists
+    all_prompts = []
+    all_preds = []
+    all_labels = []
+
+    # Add batch process to avoid the issue of too many files opened
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )

-    if isinstance(model_args.vllm_config, dict):
-        engine_args.update(model_args.vllm_config)
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)

-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-    preds = [result.outputs[0].text for result in results]
+        gc.collect()
+    # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
             f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

     print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
     print("*" * 70)
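The pattern this commit introduces, reduced to a runnable sketch: slice the dataset into fixed-size chunks so that only one chunk's worth of image/video/audio files is decoded (and therefore open) at any time, run generation per chunk on the already-built engine, and collect garbage between chunks. Everything below is illustrative stand-in code, not the script itself:

import gc

def iter_batches(dataset: list, batch_size: int):
    # Yield slices of at most batch_size rows, mirroring the patched loop:
    # only one slice's worth of media gets decoded at a time.
    for start in range(0, len(dataset), batch_size):
        yield dataset[start : min(start + batch_size, len(dataset))]

def fake_generate(batch: list) -> list:
    # Stand-in for per-sample preprocessing plus llm.generate(...) on one slice.
    return [f"prediction for sample {x}" for x in batch]

all_preds = []
for batch in iter_batches(list(range(10)), batch_size=4):
    all_preds.extend(fake_generate(batch))
    gc.collect()  # release per-batch references before the next slice, as the patch does

print(len(all_preds))  # 10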
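Finally, a hedged example of invoking the script with the new option. Only batch_size comes from this commit; the other flag names and values (model_name_or_path, template, dataset, save_name) are assumptions based on the rest of the script and the project's demo data, and may differ in your checkout:

python scripts/vllm_infer.py \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --template qwen2_vl \
    --dataset mllm_demo \
    --save_name generated_predictions.jsonl \
    --batch_size 512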