 import torch
 import sys
 import platform
+import weakref
+import gc

 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -287,11 +289,27 @@ def module_size(module):

 class LoadedModel:
     def __init__(self, model):
-        self.model = model
+        self._set_model(model)
         self.device = model.load_device
-        self.weights_loaded = False
         self.real_model = None
         self.currently_used = True
+        self.model_finalizer = None
+        self._patcher_finalizer = None
+
+    def _set_model(self, model):
+        self._model = weakref.ref(model)
+        if model.parent is not None:
+            self._parent_model = weakref.ref(model.parent)
+            self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
+
+    def _switch_parent(self):
+        model = self._parent_model()
+        if model is not None:
+            self._set_model(model)
+
+    @property
+    def model(self):
+        return self._model()

     def model_memory(self):
         return self.model.model_size()
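Note on the weakref pattern introduced above: LoadedModel no longer keeps a strong reference to the ModelPatcher; it holds a `weakref.ref` and registers a `weakref.finalize` callback that re-points the wrapper at the parent patcher once a cloned patcher is garbage collected. A minimal, self-contained sketch of that mechanism (the `Patcher` and `Holder` classes here are illustrative stand-ins, not code from this repository):

```python
import weakref

class Patcher:
    # illustrative stand-in for ModelPatcher; "parent" mirrors the clone chain
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent

class Holder:
    # mirrors LoadedModel._set_model(): keep only weak references
    def __init__(self, patcher):
        self._set_model(patcher)

    def _set_model(self, patcher):
        self._model = weakref.ref(patcher)
        if patcher.parent is not None:
            self._parent_model = weakref.ref(patcher.parent)
            # callback fires once the child patcher is garbage collected
            weakref.finalize(patcher, self._switch_parent)

    def _switch_parent(self):
        parent = self._parent_model()
        if parent is not None:
            self._set_model(parent)

    @property
    def model(self):
        return self._model()

root = Patcher("root")
child = Patcher("child", parent=root)
holder = Holder(child)
print(holder.model.name)   # child
del child                  # on CPython this frees the child immediately and runs the finalizer
print(holder.model.name)   # root: the holder silently re-attached to the parent
```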
@@ -306,32 +324,23 @@ def model_memory_required(self, device):
         return self.model_memory()

     def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
-        patch_model_to = self.device
-
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())

-        load_weights = not self.weights_loaded
+        # if self.model.loaded_size() > 0:
+        use_more_vram = lowvram_model_memory
+        if use_more_vram == 0:
+            use_more_vram = 1e32
+        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+        real_model = self.model.model

-        if self.model.loaded_size() > 0:
-            use_more_vram = lowvram_model_memory
-            if use_more_vram == 0:
-                use_more_vram = 1e32
-            self.model_use_more_vram(use_more_vram)
-        else:
-            try:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
-            except Exception as e:
-                self.model.unpatch_model(self.model.offload_device)
-                self.model_unload()
-                raise e
-
-        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
+        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
             with torch.no_grad():
-                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
+                real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

-        self.weights_loaded = True
-        return self.real_model
+        self.real_model = weakref.ref(real_model)
+        self.model_finalizer = weakref.finalize(real_model, cleanup_models)
+        return real_model

     def should_reload_model(self, force_patch_weights=False):
         if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -344,18 +353,23 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True):
                 freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
-        self.model.model_patches_to(self.model.offload_device)
-        self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.model.detach(unpatch_weights)
+        self.model_finalizer.detach()
+        self.model_finalizer = None
         self.real_model = None
         return True

-    def model_use_more_vram(self, extra_memory):
-        return self.model.partially_load(self.device, extra_memory)
+    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
+        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)

     def __eq__(self, other):
         return self.model is other.model

+    def __del__(self):
+        if self._patcher_finalizer is not None:
+            self._patcher_finalizer.detach()
+
+
 def use_more_memory(extra_memory, loaded_models, device):
     for m in loaded_models:
         if m.device == device:
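Both `model_unload()` and `__del__()` above cancel their `weakref.finalize` handles with `detach()`, so an explicit unload is not followed by a second, finalizer-driven cleanup when the object is eventually collected. A small sketch of that `detach()` behaviour (the names are hypothetical):

```python
import weakref

class RealModel:
    pass

def cleanup():
    # stands in for cleanup_models(): prune list entries whose model is gone
    print("finalizer fired")

rm = RealModel()
finalizer = weakref.finalize(rm, cleanup)

# explicit unload path: cancel the pending callback first,
# then drop the object without triggering the cleanup a second time
finalizer.detach()
del rm   # prints nothing; the callback was detached before collection
```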
@@ -386,38 +400,8 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

-def unload_model_clones(model, unload_weights_only=True, force_unload=True):
-    to_unload = []
-    for i in range(len(current_loaded_models)):
-        if model.is_clone(current_loaded_models[i].model):
-            to_unload = [i] + to_unload
-
-    if len(to_unload) == 0:
-        return True
-
-    same_weights = 0
-    for i in to_unload:
-        if model.clone_has_same_weights(current_loaded_models[i].model):
-            same_weights += 1
-
-    if same_weights == len(to_unload):
-        unload_weight = False
-    else:
-        unload_weight = True
-
-    if not force_unload:
-        if unload_weights_only and unload_weight == False:
-            return None
-    else:
-        unload_weight = True
-
-    for i in to_unload:
-        logging.debug("unload clone {} {}".format(i, unload_weight))
-        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-    return unload_weight
-
 def free_memory(memory_required, device, keep_loaded=[]):
+    cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -454,6 +438,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     return unloaded_models

 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    cleanup_models_gc()
     global vram_state

     inference_memory = minimum_inference_memory()
@@ -466,63 +451,45 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     models = set(models)

     models_to_load = []
-    models_already_loaded = []
+
     for x in models:
         loaded_model = LoadedModel(x)
-        loaded = None
-
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)
         except:
             loaded_model_index = None

         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
-            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
-                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                loaded = None
-            else:
-                loaded.currently_used = True
-                models_already_loaded.append(loaded)
-
-        if loaded is None:
+            loaded.currently_used = True
+            models_to_load.append(loaded)
+        else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)

-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
-                free_mem = get_free_memory(d)
-                if free_mem < minimum_memory_required:
-                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
-                    models_to_load = free_memory(minimum_memory_required, d)
-                    logging.info("{} models unloaded.".format(len(models_to_load)))
-                else:
-                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
-        if len(models_to_load) == 0:
-            return
-
-    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    for loaded_model in models_to_load:
+        to_unload = []
+        for i in range(len(current_loaded_models)):
+            if loaded_model.model.is_clone(current_loaded_models[i].model):
+                to_unload = [i] + to_unload
+        for i in to_unload:
+            current_loaded_models.pop(i).model.detach(unpatch_all=False)

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

-    for loaded_model in models_already_loaded:
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+            free_mem = get_free_memory(device)
+            if free_mem < minimum_memory_required:
+                models_l = free_memory(minimum_memory_required, device)
+                logging.info("{} models unloaded.".format(len(models_l)))

     for loaded_model in models_to_load:
         model = loaded_model.model
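The `current_loaded_models.index(loaded_model)` lookup above works because `LoadedModel.__eq__` compares the wrapped patchers by identity, so a freshly built wrapper matches the entry already stored for the same patcher. A tiny runnable illustration of that idiom with simplified stand-in classes:

```python
class Patcher:
    pass

class Loaded:
    # simplified stand-in for LoadedModel: equality means "wraps the same patcher"
    def __init__(self, model):
        self.model = model

    def __eq__(self, other):
        return self.model is other.model

p = Patcher()
current = [Loaded(p)]

# a brand-new wrapper around the same patcher compares equal to the stored one,
# so list.index() finds the existing entry instead of raising ValueError
print(current.index(Loaded(p)))   # 0
```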
@@ -544,17 +511,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

         cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
-
-
-    devs = set(map(lambda a: a.device, models_already_loaded))
-    for d in devs:
-        if d != torch.device("cpu"):
-            free_mem = get_free_memory(d)
-            if free_mem > minimum_memory_required:
-                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
     return

-
 def load_model_gpu(model):
     return load_models_gpu([model])

@@ -568,21 +526,35 @@ def loaded_models(only_currently_used=False):
             output.append(m.model)
     return output

-def cleanup_models(keep_clone_weights_loaded=False):
+
+def cleanup_models_gc():
+    do_gc = False
+    for i in range(len(current_loaded_models)):
+        cur = current_loaded_models[i]
+        if cur.real_model() is not None and cur.model is None:
+            logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
+            do_gc = True
+            break
+
+    if do_gc:
+        gc.collect()
+        soft_empty_cache()
+
+        for i in range(len(current_loaded_models)):
+            cur = current_loaded_models[i]
+            if cur.real_model() is not None and cur.model is None:
+                logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
+
+
+
+def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
-        #TODO: very fragile function needs improvement
-        num_refs = sys.getrefcount(current_loaded_models[i].model)
-        if num_refs <= 2:
-            if not keep_clone_weights_loaded:
-                to_delete = [i] + to_delete
-            #TODO: find a less fragile way to do this.
-            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
-                to_delete = [i] + to_delete
+        if current_loaded_models[i].real_model() is None:
+            to_delete = [i] + to_delete

     for i in to_delete:
         x = current_loaded_models.pop(i)
-        x.model_unload()
         del x

 def dtype_size(dtype):
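`cleanup_models_gc()` above falls back to `gc.collect()` because an object caught in a reference cycle is not freed by reference counting alone: its weakref keeps resolving even after the last external reference is dropped, which is exactly the `real_model() is not None and model is None` situation the function flags. A short runnable illustration (generic objects, not the real model classes):

```python
import gc
import weakref

class Module:
    pass

a = Module()
b = Module()
a.other = b
b.other = a              # reference cycle, e.g. submodules pointing back at a parent

leak_probe = weakref.ref(a)
del a, b                 # refcounts never reach zero because of the cycle
print(leak_probe() is None)   # False (assuming the cycle collector has not run yet)

gc.collect()             # the cyclic garbage collector breaks the cycle
print(leak_probe() is None)   # True: the objects are gone and the weakref cleared
```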