
Commit 06c6610

Memory estimation code can now take into account conds. (#8307)
1 parent c9e1821 commit 06c6610

File tree: 4 files changed, +47 −7 lines


comfy/conds.py

Lines changed: 8 additions & 0 deletions

@@ -24,6 +24,10 @@ def concat(self, others):
             conds.append(x.cond)
         return torch.cat(conds)
 
+    def size(self):
+        return list(self.cond.size())
+
+
 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
         data = self.cond
@@ -64,6 +68,7 @@ def concat(self, others):
             out.append(c)
         return torch.cat(out)
 
+
 class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
@@ -78,3 +83,6 @@ def can_concat(self, other):
 
     def concat(self, others):
         return self.cond
+
+    def size(self):
+        return [1]
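
Note: the new size() hook gives callers a uniform way to ask any cond wrapper for its shape; CONDConstant carries no tensor, so it reports a unit shape. A minimal sketch of the behavior, assuming a ComfyUI checkout is importable and that CONDRegular's constructor simply stores the tensor (as CONDConstant's does above); the tensor shape is illustrative:

import torch
import comfy.conds

# A regular cond wraps a tensor, so size() mirrors the tensor's shape.
cond = comfy.conds.CONDRegular(torch.zeros(1, 77, 768))
print(cond.size())   # [1, 77, 768]

# A constant cond carries no tensor, so it reports a unit size.
const = comfy.conds.CONDConstant("token")
print(const.size())  # [1]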

comfy/model_base.py

Lines changed: 13 additions & 3 deletions

@@ -135,6 +135,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.memory_usage_factor_conds = ()
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -325,19 +326,28 @@ def blank_inpaint_image_like(latent_image):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
 
-    def memory_required(self, input_shape):
+    def memory_required(self, input_shape, cond_shapes={}):
+        input_shapes = [input_shape]
+        for c in self.memory_usage_factor_conds:
+            shape = cond_shapes.get(c, None)
+            if shape is not None and len(shape) > 0:
+                input_shapes += shape
+
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
+    def extra_conds_shapes(self, **kwargs):
+        return {}
+
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
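
Note: memory_required now sums the attention "area" (batch size times the product of the dimensions after the channel dimension) over the noise input plus every cond key the model lists in memory_usage_factor_conds. A simplified standalone sketch of the flash-attention branch; the shapes, fp16 dtype size, and factor of 1.0 here are illustrative assumptions:

import math

def estimate_bytes(input_shapes, dtype_size=2, memory_usage_factor=1.0):
    # "area" per shape is batch * prod(dims after the channel dim),
    # summed over the noise latent and every counted cond shape.
    area = sum(s[0] * math.prod(s[2:]) for s in input_shapes)
    return area * dtype_size * 0.01 * memory_usage_factor * (1024 * 1024)

# Noise latent [2, 4, 64, 64] plus one counted cond of shape [2, 77, 768]:
# area = 2*64*64 + 2*768 = 9728, so the estimate is 9728 * 2 * 0.01 MiB.
print(estimate_bytes([[2, 4, 64, 64], [2, 77, 768]]))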

comfy/sampler_helpers.py

Lines changed: 19 additions & 3 deletions

@@ -1,5 +1,7 @@
 from __future__ import annotations
 import uuid
+import math
+import collections
 import comfy.model_management
 import comfy.conds
 import comfy.utils
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
         if hasattr(m, 'cleanup'):
             m.cleanup()
 
+def estimate_memory(model, noise_shape, conds):
+    cond_shapes = collections.defaultdict(list)
+    cond_shapes_min = {}
+    for _, cs in conds.items():
+        for cond in cs:
+            for k, v in model.model.extra_conds_shapes(**cond).items():
+                cond_shapes[k].append(v)
+                if cond_shapes_min.get(k, None) is None:
+                    cond_shapes_min[k] = [v]
+                elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
+                    cond_shapes_min[k] = [v]
+
+    memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
+    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
+    return memory_required, minimum_memory_required
 
 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
-    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
-    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
+    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
 
     return real_model, conds, models
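
Note the asymmetry in estimate_memory: the full estimate collects every cond shape per key, while cond_shapes_min keeps only the single largest shape (by element count) per key, since even the smallest viable batch must still fit its biggest cond. A standalone sketch of that selection logic; the key name and shapes are made up:

import math
import collections

def collect_shapes(per_cond_shapes):
    # per_cond_shapes: one {key: shape} dict per cond, standing in for
    # what model.model.extra_conds_shapes(**cond) would return.
    cond_shapes = collections.defaultdict(list)
    cond_shapes_min = {}
    for shapes in per_cond_shapes:
        for k, v in shapes.items():
            cond_shapes[k].append(v)
            # Keep only the largest shape per key for the minimum estimate.
            if k not in cond_shapes_min or math.prod(v) > math.prod(cond_shapes_min[k][0]):
                cond_shapes_min[k] = [v]
    return cond_shapes, cond_shapes_min

full, minimal = collect_shapes([{"c_crossattn": [1, 77, 768]},
                                {"c_crossattn": [1, 231, 768]}])
print(full["c_crossattn"])     # [[1, 77, 768], [1, 231, 768]]
print(minimal["c_crossattn"])  # [[1, 231, 768]]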

comfy/samplers.py

Lines changed: 7 additions & 1 deletion

@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) * 1.5 < free_memory:
+            cond_shapes = collections.defaultdict(list)
+            for tt in batch_amount:
+                cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
+                for k, v in to_run[tt][0].conditioning.items():
+                    cond_shapes[k].append(v.size())
+
+            if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                 to_batch = batch_amount
                 break
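
The surrounding loop in _calc_cond_batch probes progressively smaller prefixes of the candidate batch (all of it, then 1/2, then 1/3, ...) and takes the first one whose memory estimate, now cond-aware, fits in free memory with 1.5x headroom. A toy illustration of that search; the candidate list, free-memory figure, and cost function are made up:

to_batch_temp = list(range(8))      # hypothetical batch candidates
free_memory = 4 * 1024**3           # hypothetical free VRAM, 4 GiB

def memory_required(n):             # stand-in for model.memory_required
    return n * 1024**3              # pretend each item costs 1 GiB

to_batch = to_batch_temp[:1]
for i in range(1, len(to_batch_temp) + 1):
    batch_amount = to_batch_temp[:len(to_batch_temp) // i]
    if memory_required(len(batch_amount)) * 1.5 < free_memory:
        to_batch = batch_amount
        break
print(to_batch)  # first prefix that fits with headroom: [0, 1]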
