1) Grab https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true
2) Adjust paths at the bottom of the script
3) Make sure torch and safetensors are installed. Activating an existing environment that already has them (for example your automatic1111 venv) may be enough. For example, on Linux:
#!/bin/bash
source PATH_TO_AUTOMATIC1111/venv/bin/activate
python PATH_TO_SCRIPT/swap_vae_sdxl.py
before:
after:
'
from pathlib import Path
from datetime import datetime
# You need to have torch and safetensors installed somewhere
# Might be enough to source some existing environment, like your automatic1111 venv, first
import safetensors.torch
import torch
def load_state_dict(in_path: Path | str):
in_path = Path(in_path)
if in_path.suffix.lower() == ".ckpt":
weights = torch.load(str(in_path), map_location="cpu")
else:
weights = safetensors.torch.load_file(str(in_path), device="cpu")
return weights['state_dict'] if 'state_dict' in weights else weights
def swap_vae_sdxl(
    tgt_path,
    vae_path,
    out_path=None,
    prefix="first_stage_model.",
    excludes=("model_ema", "loss"),
    includes=None,
    conversion_dtype="auto",
    additional_metadata=None
):
    """Replace the VAE weights inside a checkpoint with those from a standalone VAE file.

    Every selected entry of the VAE file is written into the target state dict
    under ``prefix + key``. Keys missing from the target are added anyway, with
    a printed notice. The result is saved as a ``.safetensors`` file.

    Args:
        tgt_path: checkpoint whose VAE should be replaced (loaded via load_state_dict).
        vae_path: standalone VAE weights to insert (loaded via load_state_dict).
        out_path: destination file. Defaults to ``<tgt stem>-vaefix.safetensors``
            next to tgt_path.
        prefix: key prefix under which the VAE lives in the target checkpoint.
        excludes: VAE keys containing any of these substrings are skipped.
        includes: if given, only VAE keys containing one of these substrings are copied.
        conversion_dtype: dtype that floating-point VAE tensors are cast to.
            ``"auto"`` uses the dtype of the target's existing tensors under
            ``prefix``, falling back to the first floating-point tensor in the
            target. If the target has none, tensors are left as-is. A falsy
            value disables conversion.
        additional_metadata: extra entries merged into the safetensors header
            metadata. Keys and values are converted to ``str``.
    """
    tgt_path = Path(tgt_path)
    vae_path = Path(vae_path)
    sd_tgt = load_state_dict(tgt_path)
    sd_vae = load_state_dict(vae_path)

    def _is_float_tensor(value):
        # True only for torch tensors with a floating-point dtype.
        return isinstance(value, torch.Tensor) and torch.is_floating_point(value)

    if conversion_dtype == "auto":
        # Prefer the dtype of the VAE being replaced. Mixed-precision checkpoints
        # may store the VAE differently from whichever tensor comes first.
        reference = next(
            (v for k, v in sd_tgt.items() if k.startswith(prefix) and _is_float_tensor(v)),
            None,
        )
        if reference is None:
            reference = next((v for v in sd_tgt.values() if _is_float_tensor(v)), None)
        if reference is None:
            # Leaving "auto" here would later reach Tensor.to("auto") and raise.
            conversion_dtype = None
            print("no floating-point tensors in tgt model, keeping vae tensor dtypes unchanged")
        else:
            conversion_dtype = reference.dtype
            print(f"found {reference.dtype} in tgt model, auto converting vae tensors to {conversion_dtype}")

    for key_vae, val_vae in sd_vae.items():
        if excludes and any(e in key_vae for e in excludes):
            print(f"exclude: {key_vae}")
            continue
        if includes and not any(e in key_vae for e in includes):
            continue
        key_tgt = prefix + key_vae
        if key_tgt not in sd_tgt:
            print(f"key {key_tgt} is missing in tgt model, adding key anyway")
        if conversion_dtype and _is_float_tensor(val_vae):
            sd_tgt[key_tgt] = val_vae.to(conversion_dtype)
        else:
            sd_tgt[key_tgt] = val_vae

    # safetensors can only serialize tensors. Pickled checkpoints may carry
    # non-tensor entries that would otherwise make save_file fail.
    for key in [k for k, v in sd_tgt.items() if not isinstance(v, torch.Tensor)]:
        print(f"dropping non-tensor entry {key}: safetensors can only store tensors")
        del sd_tgt[key]

    if not out_path:
        out_path = tgt_path.with_name(tgt_path.stem + "-vaefix.safetensors")
    meta = {"generator": "swap_vae_sdxl.py", "datetime": datetime.now().isoformat(), "tgt_name": tgt_path.name, "vae_name": vae_path.name}
    if additional_metadata:
        # The safetensors header metadata must map str -> str.
        meta |= {str(k): str(v) for k, v in additional_metadata.items()}
    safetensors.torch.save_file(sd_tgt, str(out_path), metadata=meta)
    print(f"file written to {out_path}")
# Example: swap in madebyollin's fp16-fix SDXL VAE, downloaded from
# https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors
# (direct download: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true)
# and renamed to madebyollin_sdxl_vae_fix.safetensors. Adjust both paths before running.
# The guard keeps the swap from running when this module is imported.
if __name__ == "__main__":
    swap_vae_sdxl(
        tgt_path="YOUR_PATH/CHEYENNE_v14.safetensors",
        vae_path="YOUR_PATH/madebyollin_sdxl_vae_fix.safetensors"
    )