Make torch.compile LoRA/key-compatible #8213
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Originally, torch.compile was implemented as an object_patch because there was no other built-in way to deal with changing an internal object to another without monkey patching.
However, this means the entirety of diffusion_model is changed to the OptimizedModel class as soon as the model begins to load, even before any weight patches are applied. The actual 'diffusion_model' is stored on the OptimizedModel's ._orig_mod parameter. So, anything relating to keys no longer applies, whether it be loras, hooks, or a simple get attr call.
This PR replaces the object_patch usage with a built-in wrapper that got added late last year as part of the hooks PR. It replaces diffusion_model with the torch.compile'd OptimizedModel object only when it gets to BaseModel.apply_model function, and replaces it with the original diffusion_model as soon as it leaves, for every step of sampling. This is only a reference swap so has basically no cost, and this is about as deep in the sampling code the existing wrappers allow for (meaning very little should be broken now by torch.compile).
The OptimizedModel object gets reused from the torch.compile call same as before.