# See the License for the specific language governing permissions and
# limitations under the License.

+import gc
import json
from typing import Optional

import fire
from transformers import Seq2SeqTrainingArguments
+from tqdm import tqdm

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
@@ -53,6 +55,7 @@ def vllm_infer(
    image_min_pixels: int = 32 * 32,
    video_fps: float = 2.0,
    video_maxlen: int = 128,
+    batch_size: int = 1024,
):
    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.

@@ -85,42 +88,28 @@ def vllm_infer(
    tokenizer = tokenizer_module["tokenizer"]
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)

-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"],
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "max_model_len": cutoff_len + max_new_tokens,
+        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
+        "pipeline_parallel_size": pipeline_parallel_size,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)
+
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]

    sampling_params = SamplingParams(
        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must be > 0
@@ -137,30 +126,72 @@ def vllm_infer(
    else:
        lora_request = None

-    engine_args = {
-        "model": model_args.model_name_or_path,
-        "trust_remote_code": True,
-        "dtype": model_args.infer_dtype,
-        "max_model_len": cutoff_len + max_new_tokens,
-        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
-        "pipeline_parallel_size": pipeline_parallel_size,
-        "disable_log_stats": True,
-        "enable_lora": model_args.adapter_name_or_path is not None,
-    }
-    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+    # Store all results in these lists
+    all_prompts = []
+    all_preds = []
+    all_labels = []
+
+    # Process the dataset in batches to avoid opening too many files at once
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )

-    if isinstance(model_args.vllm_config, dict):
-        engine_args.update(model_args.vllm_config)
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)

-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-    preds = [result.outputs[0].text for result in results]
+        gc.collect()
+    # Write all results at once outside the loop
    with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
    print("*" * 70)
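
Usage note (not part of the diff): because the script hands vllm_infer to fire.Fire, the new batch_size parameter is also exposed as a --batch_size command-line flag. The sketch below calls the function directly from Python instead; the import path and every argument value other than batch_size are assumptions for illustration, not taken from this change.

# Minimal sketch, assuming the script is importable as a module named vllm_infer.
# All argument values below are placeholders; only batch_size is introduced by this change.
from vllm_infer import vllm_infer

vllm_infer(
    model_name_or_path="meta-llama/Llama-3-8B-Instruct",  # assumed checkpoint
    dataset="alpaca_en_demo",  # assumed dataset name
    batch_size=512,  # number of samples sent to llm.generate() per loop iteration
)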
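
Since each result is written with json.dumps(...) + "\n", the output file is line-delimited JSON. A small read-back sketch, assuming the file was saved as generated_predictions.jsonl (the actual name is whatever save_name was set to):

import json

# Load the predictions written by vllm_infer; the filename is an assumption.
with open("generated_predictions.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# Each record carries the keys used in the diff: "prompt", "predict", and "label".
print(len(records), records[0]["predict"] if records else None)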