
MoE finetuning extremely slow #736

Closed · H-Simpson123 opened this issue Jul 2, 2024 · 6 comments

@H-Simpson123

The finetuning of Qwen2-57B-A14B-Instruct is extremely slow compared to finetuning of Qwen2-72B-Instruct.

Here are the runtimes:

Qwen/Qwen2-7B-Instruct:
{'train_runtime': 100.8509, 'train_samples_per_second': 5.652, 'train_steps_per_second': 0.099, 'train_loss': 0.751581035554409, 'epoch': 10.0}
Qwen/Qwen2-72B-Instruct:
{'train_runtime': 483.8572, 'train_samples_per_second': 1.178, 'train_steps_per_second': 0.021, 'train_loss': 0.6512975960969924, 'epoch': 10.0}
Qwen/Qwen2-57B-A14B-Instruct:
{'train_runtime': 2713.6648, 'train_samples_per_second': 0.21, 'train_steps_per_second': 0.004, 'train_loss': 10.314393818378448, 'epoch': 10.0}

I'm using the finetune.sh / finetune.py from this repository with --use_lora True and the provided DeepSpeed ZeRO-3 config, and I've set per_device_train_batch_size to 1.

Hardware setup: 8x H100 80GB.
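
For context, the --use_lora path roughly corresponds to a peft setup like the sketch below (a minimal sketch only; the rank, alpha, dropout, and target modules here are placeholders, not the exact values used by finetune.py).

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the MoE checkpoint; "auto" picks up the checkpoint's dtype.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-57B-A14B-Instruct", torch_dtype="auto"
)
# Hypothetical LoRA hyperparameters; the repository script may use different ones.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()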

Environment:

accelerate==0.31.0
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.7.0
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.6.2
charset-normalizer==3.3.2
coloredlogs==15.0.1
datasets==2.20.0
deepspeed==0.14.4
dill==0.3.8
filelock==3.15.4
frozenlist==1.4.1
fsspec==2024.5.0
hjson==3.1.0
huggingface-hub==0.23.4
humanfriendly==10.0
idna==3.7
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.3
ninja==1.11.1.1
numpy==2.0.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.555.43
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
optimum==1.20.0
packaging==24.1
pandas==2.2.2
peft==0.11.1
protobuf==5.27.2
psutil==6.0.0
py-cpuinfo==9.0.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pydantic==2.8.0
pydantic_core==2.20.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
safetensors==0.4.3
sentencepiece==0.2.0
six==1.16.0
sympy==1.12.1
tokenizers==0.19.1
torch==2.3.1
tqdm==4.66.4
transformers==4.41.2
triton==2.3.1
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
xxhash==3.4.1
yarl==1.9.4

CUTLASS is v3.5.0

@jklj077 (Collaborator) commented Jul 2, 2024

Hi, this is expected: the MoE modeling code in transformers is not optimized, and finetune.py uses transformers.

Currently, the optimized use case is inference in vLLM with the original model, whose base implementation was contributed by the community.
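
For reference, a minimal sketch of that optimized inference path with vLLM (tensor_parallel_size=8 below matches the 8-GPU node from this issue and is only an assumption about the deployment; the prompt is just an example):

from vllm import LLM, SamplingParams

# Load the original (non-finetuned) MoE checkpoint sharded across 8 GPUs.
llm = LLM(model="Qwen/Qwen2-57B-A14B-Instruct", tensor_parallel_size=8)
params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Briefly explain mixture-of-experts models."], params)
print(outputs[0].outputs[0].text)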

@H-Simpson123 (Author)

Thanks for the quick update. Any plans from your side to add an optimized implementation to HF transformers?

@jklj077 (Collaborator) commented Jul 4, 2024

All MoE models in transformers compute the results of the expert FFNs in a loop, because that is simple and easy to understand, but it is inherently inefficient on GPUs. The performance is even worse when the model has many experts, which is the case for Qwen MoE. To optimize this, you need either a fused kernel implementation (as in vLLM) or methods like expert parallelism (as in Megatron-Core).

To be frank, I don't think it could be done in transformers in the short term.
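
For illustration, the loop-over-experts pattern described above looks roughly like the sketch below (simplified; this is not the actual transformers Qwen2-MoE code, which also has a shared expert, gated FFNs, and routing-weight normalization, and the layer sizes and expert counts here are only illustrative). Each expert launches a small matmul over only its routed tokens inside a Python loop, instead of one fused kernel over all tokens.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoEBlock(nn.Module):
    """Simplified sketch of the per-expert loop used by MoE blocks in transformers."""
    def __init__(self, hidden=3584, ffn=2560, num_experts=64, top_k=8):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, ffn), nn.SiLU(), nn.Linear(ffn, hidden))
             for _ in range(num_experts)]
        )

    def forward(self, x):  # x: (num_tokens, hidden)
        # Route each token to its top-k experts.
        weights, idx = torch.topk(F.softmax(self.gate(x), dim=-1), self.top_k, dim=-1)
        out = torch.zeros_like(x)
        # The slow part: a Python loop over all experts, each running a
        # small GEMM over only the tokens routed to it.
        for e, expert in enumerate(self.experts):
            rows, slots = torch.where(idx == e)
            if rows.numel() == 0:
                continue
            out[rows] += weights[rows, slots].unsqueeze(-1) * expert(x[rows])
        return out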


This issue has been automatically marked as inactive due to lack of recent activity. Should you believe it remains unresolved and warrants attention, kindly leave a comment on this thread.

github-actions bot closed this as not planned on Sep 5, 2024
@FL77N commented Sep 24, 2024

@jklj077 Hi, when I use transformers to SFT Qwen2 57B MoE on 32x A100 80GB with input length 2048, it goes OOM. Is there something wrong with my usage?


This issue has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

github-actions bot locked this as resolved and limited conversation to collaborators on Feb 24, 2025