Skip to content

Commit f879cac

Browse files
Merge pull request #12311 from AUTOMATIC1111/efficient-vae-methods
Add TAESD(or more) options for all the VAE encode/decode operation
2 parents ad510b2 + b85ec2b commit f879cac

6 files changed

+100
-27
lines changed

modules/generation_parameters_copypaste.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,12 @@ def parse_generation_parameters(x: str):
307307
if "Schedule rho" not in res:
308308
res["Schedule rho"] = 0
309309

310+
if "VAE Encoder" not in res:
311+
res["VAE Encoder"] = "Full"
312+
313+
if "VAE Decoder" not in res:
314+
res["VAE Decoder"] = "Full"
315+
310316
return res
311317

312318

@@ -332,6 +338,8 @@ def parse_generation_parameters(x: str):
332338
('RNG', 'randn_source'),
333339
('NGMS', 's_min_uncond'),
334340
('Pad conds', 'pad_cond_uncond'),
341+
('VAE Encoder', 'sd_vae_encode_method'),
342+
('VAE Decoder', 'sd_vae_decode_method'),
335343
]
336344

337345

modules/processing.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import modules.sd_hijack
1717
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
1818
from modules.sd_hijack import model_hijack
19+
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
1920
from modules.shared import opts, cmd_opts, state
2021
import modules.shared as shared
2122
import modules.paths as paths
@@ -30,7 +31,6 @@
3031
from einops import repeat, rearrange
3132
from blendmodes.blend import blendLayers, BlendType
3233

33-
decode_first_stage = sd_samplers_common.decode_first_stage
3434

3535
# some of those options should not be changed at all because they would break the model, so I removed them from options.
3636
opt_C = 4
@@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
8484

8585
# The "masked-image" in this case will just be all zeros since the entire image is masked.
8686
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
87-
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
87+
image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
8888

8989
# Add the fake full 1s mask to the first dimension.
9090
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@@ -203,7 +203,7 @@ def depth2img_image_conditioning(self, source_image):
203203
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
204204
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
205205

206-
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
206+
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
207207
conditioning = torch.nn.functional.interpolate(
208208
self.sd_model.depth_model(midas_in),
209209
size=conditioning_image.shape[2:],
@@ -216,7 +216,7 @@ def depth2img_image_conditioning(self, source_image):
216216
return conditioning
217217

218218
def edit_image_conditioning(self, source_image):
219-
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
219+
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
220220

221221
return conditioning_image
222222

@@ -795,6 +795,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
795795
if getattr(samples_ddim, 'already_decoded', False):
796796
x_samples_ddim = samples_ddim
797797
else:
798+
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
798799
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
799800

800801
x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1135,11 +1136,10 @@ def save_intermediate(image, index):
11351136
batch_images.append(image)
11361137

11371138
decoded_samples = torch.from_numpy(np.array(batch_images))
1138-
decoded_samples = decoded_samples.to(shared.device)
1139-
decoded_samples = 2. * decoded_samples - 1.
11401139
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
11411140

1142-
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
1141+
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
1142+
samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
11431143

11441144
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
11451145

@@ -1374,10 +1374,9 @@ def init(self, all_prompts, all_seeds, all_subseeds):
13741374
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
13751375

13761376
image = torch.from_numpy(batch_images)
1377-
image = 2. * image - 1.
13781377
image = image.to(shared.device, dtype=devices.dtype_vae)
1379-
1380-
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
1378+
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
1379+
self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
13811380
devices.torch_gc()
13821381

13831382
if self.resize_mode == 3:

modules/sd_samplers_common.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
2323
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
2424

2525

26-
def single_sample_to_image(sample, approximation=None):
26+
def samples_to_images_tensor(sample, approximation=None, model=None):
27+
'''latents -> images [-1, 1]'''
2728
if approximation is None:
2829
approximation = approximation_indexes.get(opts.show_progress_type, 0)
2930

3031
if approximation == 2:
31-
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
32+
x_sample = sd_vae_approx.cheap_approximation(sample)
3233
elif approximation == 1:
33-
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
34+
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
3435
elif approximation == 3:
3536
x_sample = sample * 1.5
36-
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
37+
x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
38+
x_sample = x_sample * 2 - 1
3739
else:
38-
x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
40+
if model is None:
41+
model = shared.sd_model
42+
x_sample = model.decode_first_stage(sample)
43+
44+
return x_sample
45+
46+
47+
def single_sample_to_image(sample, approximation=None):
48+
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
3949

4050
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
4151
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -45,9 +55,9 @@ def single_sample_to_image(sample, approximation=None):
4555

4656

4757
def decode_first_stage(model, x):
48-
x = model.decode_first_stage(x.to(devices.dtype_vae))
49-
50-
return x
58+
x = x.to(devices.dtype_vae)
59+
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
60+
return samples_to_images_tensor(x, approx_index, model)
5161

5262

5363
def sample_to_image(samples, index=0, approximation=None):
@@ -58,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
5868
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
5969

6070

71+
def images_tensor_to_samples(image, approximation=None, model=None):
72+
'''image[0, 1] -> latent'''
73+
if approximation is None:
74+
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
75+
76+
if approximation == 3:
77+
image = image.to(devices.device, devices.dtype)
78+
x_latent = sd_vae_taesd.encoder_model()(image)
79+
else:
80+
if model is None:
81+
model = shared.sd_model
82+
image = image.to(shared.device, dtype=devices.dtype_vae)
83+
image = image * 2 - 1
84+
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
85+
86+
return x_latent
87+
88+
6189
def store_latent(decoded):
6290
state.current_latent = decoded
6391

modules/sd_vae_approx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,6 @@ def cheap_approximation(sample):
8181

8282
coefs = torch.tensor(coeffs).to(sample.device)
8383

84-
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
84+
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
8585

8686
return x_sample

modules/sd_vae_taesd.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,17 @@ def decoder():
4444
)
4545

4646

47-
class TAESD(nn.Module):
47+
def encoder():
48+
return nn.Sequential(
49+
conv(3, 64), Block(64, 64),
50+
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
51+
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
52+
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
53+
conv(64, 4),
54+
)
55+
56+
57+
class TAESDDecoder(nn.Module):
4858
latent_magnitude = 3
4959
latent_shift = 0.5
5060

@@ -55,21 +65,28 @@ def __init__(self, decoder_path="taesd_decoder.pth"):
5565
self.decoder.load_state_dict(
5666
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
5767

58-
@staticmethod
59-
def unscale_latents(x):
60-
"""[0, 1] -> raw latents"""
61-
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
68+
69+
class TAESDEncoder(nn.Module):
70+
latent_magnitude = 3
71+
latent_shift = 0.5
72+
73+
def __init__(self, encoder_path="taesd_encoder.pth"):
74+
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
75+
super().__init__()
76+
self.encoder = encoder()
77+
self.encoder.load_state_dict(
78+
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
6279

6380

6481
def download_model(model_path, model_url):
6582
if not os.path.exists(model_path):
6683
os.makedirs(os.path.dirname(model_path), exist_ok=True)
6784

68-
print(f'Downloading TAESD decoder to: {model_path}')
85+
print(f'Downloading TAESD model to: {model_path}')
6986
torch.hub.download_url_to_file(model_url, model_path)
7087

7188

72-
def model():
89+
def decoder_model():
7390
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
7491
loaded_model = sd_vae_taesd_models.get(model_name)
7592

@@ -78,11 +95,30 @@ def model():
7895
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
7996

8097
if os.path.exists(model_path):
81-
loaded_model = TAESD(model_path)
98+
loaded_model = TAESDDecoder(model_path)
8299
loaded_model.eval()
83100
loaded_model.to(devices.device, devices.dtype)
84101
sd_vae_taesd_models[model_name] = loaded_model
85102
else:
86103
raise FileNotFoundError('TAESD model not found')
87104

88105
return loaded_model.decoder
106+
107+
108+
def encoder_model():
109+
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
110+
loaded_model = sd_vae_taesd_models.get(model_name)
111+
112+
if loaded_model is None:
113+
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
114+
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
115+
116+
if os.path.exists(model_path):
117+
loaded_model = TAESDEncoder(model_path)
118+
loaded_model.eval()
119+
loaded_model.to(devices.device, devices.dtype)
120+
sd_vae_taesd_models[model_name] = loaded_model
121+
else:
122+
raise FileNotFoundError('TAESD model not found')
123+
124+
return loaded_model.encoder

modules/shared.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ def list_samplers():
435435
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
436436
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
437437
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
438+
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
439+
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
438440
}))
439441

440442
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {

0 commit comments

Comments
 (0)