Error while serving fine-tuned Qwen 2.5 VL model #8147

Open · 1 task done
nishadsinghi opened this issue May 23, 2025 · 3 comments
Labels
bug (Something isn't working) · help wanted (Extra attention is needed) · pending (This problem is yet to be addressed)

Comments

@nishadsinghi

nishadsinghi commented May 23, 2025

Reminder

  • I have read the above rules and searched the existing issues.

System Info

[2025-05-23 13:50:43,655] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)
INFO 05-23 13:50:48 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-23 13:50:48 [init.py:239] Automatically detected platform cuda.

  • llamafactory version: 0.9.3.dev0
  • Platform: Linux-6.8.0-54-generic-x86_64-with-glibc2.39
  • Python version: 3.9.21
  • PyTorch version: 2.6.0+cu124 (GPU)
  • Transformers version: 4.52.1
  • Datasets version: 3.6.0
  • Accelerate version: 1.7.0
  • PEFT version: 0.15.2
  • TRL version: 0.9.6
  • GPU type: NVIDIA L40S
  • GPU number: 2
  • GPU memory: 44.40GB
  • DeepSpeed version: 0.16.9
  • vLLM version: 0.8.5.post1
  • Git commit: a9211a7

Reproduction

I fine-tuned Qwen 2.5 VL 3B Instruct. Then, I tried to deploy it as follows:
API_PORT=8000 llamafactory-cli api examples/inference/qwen2_5vl.yaml infer_backend=vllm vllm_enforce_eager=true
This command fails with an error. I was able to serve the base model using the same command, but not the fine-tuned version.
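For context, the inference config passed to the CLI is roughly of the following shape; the paths and template name below are illustrative assumptions rather than the exact contents of my file:

# Hypothetical sketch of examples/inference/qwen2_5vl.yaml (values are assumptions, not the real file)
model_name_or_path: saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again
template: qwen2_vl
infer_backend: vllm        # the command line also sets this explicitly
trust_remote_code: true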

Here is the error:

WARNING 05-23 13:44:38 [utils.py:2382] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initialized
INFO 05-23 13:44:51 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-23 13:44:51 [__init__.py:239] Automatically detected platform cuda.
INFO 05-23 13:44:57 [core.py:58] Initializing a V1 LLM engine (v0.8.5.post1) with config: model='saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again', speculative_config=None, tokenizer='saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
WARNING 05-23 13:44:57 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 05-23 13:44:57 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 10485760, 10, 'psm_5f7ba305'), local_subscribe_addr='ipc:///tmp/8990c269-2086-4b12-b6ed-427acf2d1b5b', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 05-23 13:45:10 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-23 13:45:10 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-23 13:45:10 [__init__.py:239] Automatically detected platform cuda.
INFO 05-23 13:45:10 [__init__.py:239] Automatically detected platform cuda.
WARNING 05-23 13:45:14 [utils.py:2522] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7208a1a5b070>
WARNING 05-23 13:45:14 [utils.py:2522] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7a3b6ea5c160>
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:14 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_6f812bd4'), local_subscribe_addr='ipc:///tmp/24ee3abb-d3dc-4467-93b2-c52e4ac7bdfd', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:14 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_e9551106'), local_subscribe_addr='ipc:///tmp/066418ee-f44e-4bc5-976d-efb34b980723', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [utils.py:1055] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:15 [utils.py:1055] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:15 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /home/ns94feza/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:15 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /home/ns94feza/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [shm_broadcast.py:266] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_1b30d56c'), local_subscribe_addr='ipc:///tmp/d5c40de6-cde0-472e-b0b6-a60f1d32ba91', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:15 [parallel_state.py:1004] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [parallel_state.py:1004] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:15 [cuda.py:221] Using Flash Attention backend on V1 engine.
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:15 [cuda.py:221] Using Flash Attention backend on V1 engine.
(VllmWorker rank=0 pid=1423872) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(VllmWorker rank=1 pid=1423873) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(VllmWorker rank=0 pid=1423872) Unused or unrecognized kwargs: fps, return_tensors.
(VllmWorker rank=1 pid=1423873) Unused or unrecognized kwargs: return_tensors, fps.
(VllmWorker rank=0 pid=1423872) WARNING 05-23 13:45:19 [topk_topp_sampler.py:69] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=1 pid=1423873) WARNING 05-23 13:45:19 [topk_topp_sampler.py:69] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:19 [gpu_model_runner.py:1329] Starting to load model saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again...
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:19 [gpu_model_runner.py:1329] Starting to load model saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
(VllmWorker rank=1 pid=1423873) INFO 05-23 13:45:19 [config.py:3614] cudagraph sizes specified by model runner [] is overridden by config []
(VllmWorker rank=0 pid=1423872) INFO 05-23 13:45:19 [config.py:3614] cudagraph sizes specified by model runner [] is overridden by config []
(VllmWorker rank=0 pid=1423872)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435] WorkerProc failed to start.
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435] Traceback (most recent call last):
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 409, in worker_main
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     worker = WorkerProc(*args, **kwargs)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 306, in __init__
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     self.worker.load_model()
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/worker/gpu_worker.py", line 162, in load_model
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     self.model_runner.load_model()
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1332, in load_model
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     self.model = get_model(vllm_config=self.vllm_config)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     return loader.load_model(vllm_config=vllm_config)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/model_loader/loader.py", line 455, in load_model
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     loaded_weights = model.load_weights(
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1126, in load_weights
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 261, in load_weights
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     autoloaded_weights = set(self._load_module("", self.module, weights))
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     yield from self._load_module(prefix,
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 195, in _load_module
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     loaded_params = module_load_weights(weights)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/qwen2.py", line 486, in load_weights
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     return loader.load_weights(weights)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 261, in load_weights
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     autoloaded_weights = set(self._load_module("", self.module, weights))
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     yield from self._load_module(prefix,
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/utils.py", line 195, in _load_module
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     loaded_params = module_load_weights(weights)
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/model_executor/models/qwen2.py", line 405, in load_weights
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435]     param = params_dict[name]
(VllmWorker rank=1 pid=1423873) ERROR 05-23 13:45:19 [multiproc_executor.py:435] KeyError: 'language_model.layers.19.input_layernorm.weight'
(VllmWorker rank=0 pid=1423872)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(VllmWorker rank=0 pid=1423872)
[rank0]:[W523 13:45:20.315050792 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ERROR 05-23 13:45:21 [core.py:396] EngineCore failed to start.
ERROR 05-23 13:45:21 [core.py:396] Traceback (most recent call last):
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
ERROR 05-23 13:45:21 [core.py:396]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 329, in __init__
ERROR 05-23 13:45:21 [core.py:396]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 64, in __init__
ERROR 05-23 13:45:21 [core.py:396]     self.model_executor = executor_class(vllm_config)
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/executor/executor_base.py", line 52, in __init__
ERROR 05-23 13:45:21 [core.py:396]     self._init_executor()
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 91, in _init_executor
ERROR 05-23 13:45:21 [core.py:396]     self.workers = WorkerProc.wait_for_ready(unready_workers)
ERROR 05-23 13:45:21 [core.py:396]   File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 370, in wait_for_ready
ERROR 05-23 13:45:21 [core.py:396]     raise e from None
ERROR 05-23 13:45:21 [core.py:396] Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
Process EngineCore_0:
Traceback (most recent call last):
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 400, in run_engine_core
    raise e
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 329, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core.py", line 64, in __init__
    self.model_executor = executor_class(vllm_config)
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/executor/executor_base.py", line 52, in __init__
    self._init_executor()
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 91, in _init_executor
    self.workers = WorkerProc.wait_for_ready(unready_workers)
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 370, in wait_for_ready
    raise e from None
Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
Traceback (most recent call last):
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/weakref.py", line 667, in _exitfunc
    f()
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/weakref.py", line 591, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/executor/multiproc_executor.py", line 228, in shutdown
    for w in self.workers:
AttributeError: 'MultiprocExecutor' object has no attribute 'workers'
Traceback (most recent call last):
  File "/home/ns94feza/miniconda3/envs/llama-factory/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
  File "/mnt/beegfs/hdd/mirror/home/ns94feza/LLaMA-Factory/src/llamafactory/cli.py", line 115, in main
    COMMAND_MAP[command]()
  File "/mnt/beegfs/hdd/mirror/home/ns94feza/LLaMA-Factory/src/llamafactory/api/app.py", line 128, in run_api
    chat_model = ChatModel()
  File "/mnt/beegfs/hdd/mirror/home/ns94feza/LLaMA-Factory/src/llamafactory/chat/chat_model.py", line 55, in __init__
    self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
  File "/mnt/beegfs/hdd/mirror/home/ns94feza/LLaMA-Factory/src/llamafactory/chat/vllm_engine.py", line 97, in __init__
    self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 684, in from_engine_args
    return async_engine_cls.from_vllm_config(
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/async_llm.py", line 150, in from_vllm_config
    return cls(
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/async_llm.py", line 118, in __init__
    self.engine_core = core_client_class(
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core_client.py", line 642, in __init__
    super().__init__(
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core_client.py", line 398, in __init__
    self._wait_for_engine_startup()
  File "/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/site-packages/vllm/v1/engine/core_client.py", line 430, in _wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above.
/home/ns94feza/miniconda3/envs/llama-factory/lib/python3.9/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Others

No response

@nishadsinghi nishadsinghi added bug Something isn't working pending This problem is yet to be addressed labels May 23, 2025
@XueSongTap

Same problem here while serving a fine-tuned Qwen 2.5 VL 3B model:

(VllmWorker rank=2 pid=21227) INFO 05-26 07:08:03 [config.py:3614] cudagraph sizes specified by model runner [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312, 320, 328, 336, 344, 352, 360, 368, 376, 384, 392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504, 512] is overridden by config [512, 384, 256, 128, 4, 2, 1, 392, 264, 136, 8, 400, 272, 144, 16, 408, 280, 152, 24, 416, 288, 160, 32, 424, 296, 168, 40, 432, 304, 176, 48, 440, 312, 184, 56, 448, 320, 192, 64, 456, 328, 200, 72, 464, 336, 208, 80, 472, 344, 216, 88, 120, 480, 352, 248, 224, 96, 488, 504, 360, 232, 104, 496, 368, 240, 112, 376]
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435] WorkerProc failed to start.
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435] Traceback (most recent call last):
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 409, in worker_main
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     worker = WorkerProc(*args, **kwargs)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 306, in __init__
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     self.worker.load_model()
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 162, in load_model
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     self.model_runner.load_model()
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1332, in load_model
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     self.model = get_model(vllm_config=self.vllm_config)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     return loader.load_model(vllm_config=vllm_config)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/loader.py", line 455, in load_model
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     loaded_weights = model.load_weights(
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                      ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1126, in load_weights
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 261, in load_weights
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     autoloaded_weights = set(self._load_module("", self.module, weights))
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     yield from self._load_module(prefix,
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 195, in _load_module
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     loaded_params = module_load_weights(weights)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 486, in load_weights
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     return loader.load_weights(weights)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 261, in load_weights
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     autoloaded_weights = set(self._load_module("", self.module, weights))
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     yield from self._load_module(prefix,
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 195, in _load_module
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     loaded_params = module_load_weights(weights)
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 405, in load_weights
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]     param = params_dict[name]
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435]             ~~~~~~~~~~~^^^^^^
(VllmWorker rank=3 pid=21228) ERROR 05-26 07:08:03 [multiproc_executor.py:435] KeyError: 'language_model.layers.19.input_layernorm.weight'
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
(VllmWorker rank=0 pid=21225) 
[rank0]:[W526 07:08:04.170710658 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ERROR 05-26 07:08:06 [core.py:396] EngineCore failed to start.
ERROR 05-26 07:08:06 [core.py:396] Traceback (most recent call last):
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
ERROR 05-26 07:08:06 [core.py:396]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 05-26 07:08:06 [core.py:396]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 329, in __init__
ERROR 05-26 07:08:06 [core.py:396]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 64, in __init__
ERROR 05-26 07:08:06 [core.py:396]     self.model_executor = executor_class(vllm_config)
ERROR 05-26 07:08:06 [core.py:396]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 52, in __init__
ERROR 05-26 07:08:06 [core.py:396]     self._init_executor()
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 91, in _init_executor
ERROR 05-26 07:08:06 [core.py:396]     self.workers = WorkerProc.wait_for_ready(unready_workers)
ERROR 05-26 07:08:06 [core.py:396]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 05-26 07:08:06 [core.py:396]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 370, in wait_for_ready
ERROR 05-26 07:08:06 [core.py:396]     raise e from None
ERROR 05-26 07:08:06 [core.py:396] Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
Process EngineCore_0:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 400, in run_engine_core
    raise e
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 329, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 64, in __init__
    self.model_executor = executor_class(vllm_config)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 52, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 91, in _init_executor
    self.workers = WorkerProc.wait_for_ready(unready_workers)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 370, in wait_for_ready
    raise e from None
Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
Traceback (most recent call last):
  File "/usr/lib/python3.12/weakref.py", line 666, in _exitfunc
    f()
  File "/usr/lib/python3.12/weakref.py", line 590, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 228, in shutdown
    for w in self.workers:
             ^^^^^^^^^^^^
AttributeError: 'MultiprocExecutor' object has no attribute 'workers'
Traceback (most recent call last):
  File "/app/scripts/vllm_infer.py", line 199, in <module>
    fire.Fire(vllm_infer)
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/app/scripts/vllm_infer.py", line 112, in vllm_infer
    llm = LLM(**engine_args)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 1161, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/llm.py", line 247, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 510, in from_engine_args
    return engine_cls.from_vllm_config(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/llm_engine.py", line 112, in from_vllm_config
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/llm_engine.py", line 92, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 73, in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 494, in __init__
    super().__init__(
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 398, in __init__
    self._wait_for_engine_startup()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 430, in _wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above.
root@c08e3a430e9c:/app# /usr/lib/python3.12/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Before the 3B model, I had served a fine-tuned 7B LoRA model with no error.

@hiyouga
Owner

hiyouga commented May 26, 2025

Currently, there are some bugs in Transformers 4.52.* that affect vLLM inference on fine-tuned models. We are working on a fix: huggingface/transformers#38385

As a temporary workaround, you can downgrade Transformers to version 4.51.3 and train again to avoid this issue.
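If retraining is not an option, the KeyError above suggests the checkpoint was saved with the newer nested weight names (an extra `language_model.` segment) that this vLLM version does not expect. Below is a rough, untested sketch of how one could inspect and rename the keys in the saved safetensors shards; the exact mapping is an assumption and should be verified against your own checkpoint (and the `model.safetensors.index.json` weight map) before overwriting anything:

import os
from safetensors.torch import load_file, save_file

# Untested sketch: drop the "language_model." segment added by Transformers 4.52-style saves,
# so a key like "model.language_model.layers.19.input_layernorm.weight" becomes
# "model.layers.19.input_layernorm.weight". Other prefixes (e.g. the vision tower) may also
# have moved and would need similar handling. Back up the original shards first.
ckpt_dir = "saves/qwen2_5vl-3b_llarp_finetuning/full/sft/try_again"

for fname in os.listdir(ckpt_dir):
    if not fname.endswith(".safetensors"):
        continue
    path = os.path.join(ckpt_dir, fname)
    tensors = load_file(path)
    remapped = {k.replace("language_model.", "", 1): v for k, v in tensors.items()}
    print(fname, "->", list(remapped)[:3])  # eyeball the new names before committing
    save_file(remapped, path)  # the index json's weight_map must be renamed the same way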

@hiyouga hiyouga marked this as a duplicate of #8163 May 27, 2025
@alanMachineLeraning

Has this problem been solved? Is it possible to run inference now without retraining?
