Skip to content

Commit f85c08d

Browse files
Make VACE conditionings stackable. (#8240)
1 parent 4202e95 commit f85c08d

File tree

3 files changed

+23
-13
lines changed

3 files changed

+23
-13
lines changed

comfy/ldm/wan/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def forward_orig(
635635
t,
636636
context,
637637
vace_context,
638-
vace_strength=1.0,
638+
vace_strength,
639639
clip_fea=None,
640640
freqs=None,
641641
transformer_options={},
@@ -661,8 +661,11 @@ def forward_orig(
661661
context = torch.concat([context_clip, context], dim=1)
662662
context_img_len = clip_fea.shape[-2]
663663

664+
orig_shape = list(vace_context.shape)
665+
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
664666
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
665667
c = c.flatten(2).transpose(1, 2)
668+
c = list(c.split(orig_shape[0], dim=0))
666669

667670
# arguments
668671
x_orig = x
@@ -682,8 +685,9 @@ def block_wrap(args):
682685

683686
ii = self.vace_layers_mapping.get(i, None)
684687
if ii is not None:
685-
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
686-
x += c_skip * vace_strength
688+
for iii in range(len(c)):
689+
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
690+
x += c_skip * vace_strength[iii]
687691
del c_skip
688692
# head
689693
x = self.head(x, e)

comfy/model_base.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,20 +1062,25 @@ def extra_conds(self, **kwargs):
10621062
vace_frames = kwargs.get("vace_frames", None)
10631063
if vace_frames is None:
10641064
noise_shape[1] = 32
1065-
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
1066-
1067-
for i in range(0, vace_frames.shape[1], 16):
1068-
vace_frames = vace_frames.clone()
1069-
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
1065+
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
10701066

10711067
mask = kwargs.get("vace_mask", None)
10721068
if mask is None:
10731069
noise_shape[1] = 64
1074-
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
1070+
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
1071+
1072+
vace_frames_out = []
1073+
for j in range(len(vace_frames)):
1074+
vf = vace_frames[j].clone()
1075+
for i in range(0, vf.shape[1], 16):
1076+
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
1077+
vf = torch.cat([vf, mask[j]], dim=1)
1078+
vace_frames_out.append(vf)
10751079

1076-
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
1080+
vace_frames = torch.stack(vace_frames_out, dim=1)
1081+
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
10771082

1078-
vace_strength = kwargs.get("vace_strength", 1.0)
1083+
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
10791084
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
10801085
return out
10811086

comfy_extras/nodes_wan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,9 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, str
268268
trim_latent = reference_image.shape[2]
269269

270270
mask = mask.unsqueeze(0)
271-
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
272-
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
271+
272+
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
273+
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
273274

274275
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
275276
out_latent = {}

0 commit comments

Comments
 (0)