Commit 6c200fd

[model] add llama4 (#7611)
1 parent 61b24c3 commit 6c200fd

11 files changed: +167 -8 lines changed

requirements.txt

Lines changed: 1 addition & 3 deletions
@@ -1,6 +1,4 @@
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
+transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
 datasets>=2.16.0,<=3.4.1
 accelerate>=0.34.0,<=1.5.2
 peft>=0.14.0,<=0.15.0

scripts/convert_ckpt/tiny_llama4.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
+
+
+if __name__ == "__main__":
+    vision_config = Llama4VisionConfig(
+        hidden_size=1408,
+        image_size=336,
+        intermediate_size=5632,
+        num_attention_heads=16,
+        num_hidden_layers=4,
+        vision_output_dim=4096,
+    )
+    text_config = Llama4TextConfig(
+        hidden_size=512,
+        intermediate_size=1024,
+        intermediate_size_mlp=1024,
+        num_hidden_layers=4,
+        num_attention_heads=8,
+        num_key_value_heads=2,
+        head_dim=512 // 8,
+        num_local_experts=2,
+    )
+    config = Llama4Config(vision_config=vision_config, text_config=text_config)
+    model = Llama4ForConditionalGeneration._from_config(config)
+    model.save_pretrained("tiny-llama4")
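Note: the script above writes a randomly initialized, tiny-config checkpoint meant for fast tests rather than inference. A minimal sketch of loading it back for a smoke test, assuming a transformers release that ships the Llama4 classes (the 4.51.0 ceiling pinned elsewhere in this commit):

from transformers import Llama4ForConditionalGeneration

# Load the tiny random-weight checkpoint written by the script above.
model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
print(f"parameters: {model.num_parameters():,}")  # tiny by construction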

src/llamafactory/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 Dependency graph:
 main:
-  transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
+  transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
   datasets>=2.16.0,<=3.4.1
   accelerate>=0.34.0,<=1.5.2
   peft>=0.14.0,<=0.15.0

src/llamafactory/data/mm_plugin.py

Lines changed: 68 additions & 0 deletions
@@ -466,6 +466,73 @@ def get_mm_inputs(
         return mm_inputs
 
 
+@dataclass
+class Llama4Plugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+                num_patches_per_chunk = int(
+                    (image_height // processor.patch_size)
+                    * (image_width // processor.patch_size)
+                    // processor.downsample_ratio
+                )
+                aspect_ratios = mm_inputs.pop("aspect_ratios")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            placeholder_count = content.count(IMAGE_PLACEHOLDER)
+            if self.expand_mm_tokens:
+                prompt_splits = content.split(IMAGE_PLACEHOLDER)
+                new_content = []
+                for local_image_index, split_part in enumerate(prompt_splits):
+                    new_content.append(split_part)
+                    if local_image_index < placeholder_count:
+                        tokens_for_this_image = processor._prompt_split_image(
+                            aspect_ratios[num_image_tokens], num_patches_per_chunk
+                        )
+                        num_image_tokens += 1
+                        new_content.append(tokens_for_this_image)
+
+                content = "".join(new_content)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("aspect_ratios", None)
+        return mm_inputs
+
+
 @dataclass
 class LlavaPlugin(BasePlugin):
     @override
@@ -1485,6 +1552,7 @@ def process_messages(
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,

src/llamafactory/data/template.py

Lines changed: 20 additions & 0 deletions
@@ -968,6 +968,26 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 )
 
 
+register_template(
+    name="llama4",
+    format_user=StringFormatter(
+        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
+    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
+    format_observation=StringFormatter(
+        slots=[
+            "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
+        ]
+    ),
+    format_tools=ToolFormatter(tool_format="llama3"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot|>", "<|eom|>"],
+    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
+)
+
+
 # copied from llama3 template
 register_template(
     name="mllama",

src/llamafactory/extras/constants.py

Lines changed: 24 additions & 0 deletions
@@ -1111,6 +1111,30 @@ def register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Llama-4-Scout-17B-16E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
+        },
+        "Llama-4-Scout-17B-16E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
+        },
+        "Llama-4-Maverick-17B-128E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
+        },
+        "Llama-4-Maverick-17B-128E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
+        },
+    },
+    template="llama4",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {

src/llamafactory/extras/misc.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.4.1")
     check_version("accelerate>=0.34.0,<=1.5.2")
     check_version("peft>=0.14.0,<=0.15.0")

src/llamafactory/model/model_utils/checkpointing.py

Lines changed: 4 additions & 1 deletion
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 
     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__
 
         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
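Note: the new branch handles checkpointed functions wrapped in functools.partial, where the bound method (and therefore __self__) sits one level down on .func. A minimal sketch of the attribute lookup this fixes:

from functools import partial

import torch

layer = torch.nn.Linear(4, 4)
bound = layer.forward             # bound method: exposes __self__
wrapped = partial(layer.forward)  # partial object: does not

assert bound.__self__ is layer
assert not hasattr(wrapped, "__self__")  # the old code raised AttributeError here
assert wrapped.func.__self__ is layer    # what the new branch reads instead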

src/llamafactory/model/model_utils/visual.py

Lines changed: 6 additions & 0 deletions
@@ -203,6 +203,12 @@ def patch_target_modules(
 )
 
 
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )

src/llamafactory/train/ppo/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
             mini_batch = {
                 "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
-                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
             }
             mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
             mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)

tests/version.txt

Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-0.9.3.100
+# change if test fails
+0.9.3.101
