Update: IT WORKS!!!
Inverted cosine is a huge winner. It provides massive boosts in fidelity and accuracy when using Simulacrum V4, and it introduces much, much more information than I could have expected.
The math listed here is similar to the math actually used when training the model, and it seems to have activated a series of pathways that have gone under the radar, a kind of bypass that is normally suppressed unrelentingly by the stronger UNET sections.
It seems to have enabled my CLIP_L to truly shine in a way that is very unexpected and very welcome.
It has enabled behavior similar to core Flux, while retaining many of the SFW and NSFW elements that I trained into Simulacrum's core and the subsequent follow-up trainings!
This has made me very excited!
In the process, it seems to have enabled consistent CFG behavior as well. What a treat of a byproduct.
Both CFG and Distilled CFG have functional negative prompting with this.
Try it yourself. You will be shocked. If you can share the EXACT math for why this works, or improve on it, that would be ideal.
Enabling standard CFG alongside the distilled CFG produces much better results overall now. Every model I've tried has behaved more like the de-distilled model.
Update: Here's some Forge code to play with.
Replace the whole nn\flux.py file with this one.
Inside you'll see a method named translate_guidance at the top. You can adjust its inputs and tweak the various parameters to play with the behavior.
You'll likely need to fully restart Forge after each change, since I didn't bother rigging it up to a notebook or a toggle elsewhere, but it definitely works and is a good proof of concept.
These formulas can be expanded, changed, added to, whatever. Just stick to the methodology.
Currently it's set to print a bunch of debug info, so you can watch the console to see exactly what is happening if you wish.
I've updated the code slightly down at the bottom, so nothing below translate_guidance needs to be edited if you paste the whole file; you can simply change its "method" parameter at the top to select the desired guidance variation.
I've adjusted inverted cosine to produce a more consistent and accurate outcome value.
Based on my testing, inverted cosine is better served with a very high peak at the start and finish, with the middle left to a lower distilled CFG.
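For example, to switch from inverted cosine to the plain cosine variant, change the method defaults in both places that name it, since inner_forward passes its own guidance_method default down into translate_guidance. These are just the two signature lines from the file below, shown in isolation:

def translate_guidance(timestep, guidance, device, method="cosine"):
    ...

def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, guidance_method="cosine"):
    ...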
import math
import torch
from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter
def translate_guidance(timestep, guidance, device, method="inverted_cosine"):
"""
Interpolate and transform guidance strength based on a normalized 'timestep' in [0,1].
Arguments:
timestep : scalar or 1D tensor in [0,1], representing the normalized progress.
- 0 => beginning of inference
- 1 => end of inference
guidance : scalar or 1D tensor containing the original guidance value(s).
device : torch device on which computations will run.
method : Which transformation method to use on the guidance.
The returned guidance = (1 - t)*original_guidance + t*transformed_guidance.
The difference here is that 'transformed_guidance' itself may depend on `t`,
so we incorporate `t` both in the transformation and in the final interpolation.
    Possible examples of 'method':
        1) "cosine"          : cos(pi*(g + t))
        2) "inverted_cosine" : 5.0 - cos(pi*(g + t))
        3) "sin"             : sin(pi*(g + t))
        4) "linear_increase" : g + 1.5*t
        5) "linear_decrease" : g - 1.5*t
        6) "random_noise"    : add uniform(-0.75, 0.75) noise, scaled by t
        7) "random_gaussian" : add normal(0, std=0.3) noise, scaled by t
        8) "random_extreme"  : add uniform(-2, 2) noise, scaled by t
    (The random variants draw fresh noise on every call.)
    You can of course tailor these transformations to your training/inference scheme.
"""
# Ensure timestep is [0,1].
t = timestep.to(device, dtype=torch.float32).clamp_(0.0, 1.0)
# Original guidance as float on the correct device
g = guidance.to(device, dtype=torch.float32)
# 1) Decide how the *transformed* guidance is computed, including dependence on t
if method == "cosine":
# Incorporate t so that the transform changes over time
# e.g. cos(pi*(g + t)) yields a changing transformation
transformed = torch.cos(math.pi * (g + t))
elif method == "inverted_cosine":
        # Shift upward by 5.0, and incorporate t so the value evolves over time.
        # '5.0 - cos(pi*g)' alone stays constant whenever g is constant,
        # so adding t ensures it shifts each step:
        transformed = 5.0 - torch.cos(math.pi * (g + t))
elif method == "sin":
# Use g + t inside the sine for changing transformations
transformed = torch.sin(math.pi * (g + t))
elif method == "linear_increase":
# Let the transformation itself incorporate t. For example:
transformed = g + (1.5 * t)
elif method == "linear_decrease":
transformed = g - (1.5 * t)
elif method == "random_noise":
# If you truly want random noise each call, apply it here
# But be aware that if you keep calling it, you'll get new noise every time
noise = torch.empty_like(g).uniform_(-0.75, 0.75)
# Optionally incorporate t in the final scaled noise
transformed = g + (noise * t)
elif method == "random_gaussian":
# Same idea, but Gaussian
noise = torch.randn_like(g) * 0.3
transformed = g + (noise * t)
elif method == "random_extreme":
noise = torch.empty_like(g).uniform_(-2.0, 2.0)
transformed = g + (noise * t)
else:
# Fallback if method is invalid or not recognized
transformed = g
# 2) Interpolate between the original guidance (g) and the new transformed value
# The factor t in the interpolation means:
# - When t=0, you stick to the original g.
# - When t=1, you fully adopt 'transformed'.
# If the transform also uses t, you get a more pronounced shift over time.
print("Original guidance:", g)
print("Transformed guidance:", transformed)
print("Linear progress (t):", t)
print("Adjusted Transform:", t * transformed)
print("Timestep Adjusted:", (1.0 - t) * g)
print("Resulting guidance:", (1.0 - t) * g + (t * transformed))
out = (1.0 - t) * g + (t * transformed)
# You can log for debugging:
# print("Original guidance:", g)
# print("Transformed guidance:", transformed)
# print("Linear progress (t):", t)
# print("Resulting guidance:", out)
return out
# --------------------------------------------------------------------------------
# Core attention + RoPE
# --------------------------------------------------------------------------------
def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe)
x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
return x
def rope(pos, dim, theta):
if pos.device.type == "mps" or pos.device.type == "xpu":
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
else:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
del cos_out, sin_out
b, n, d, _ = out.shape
out = out.view(b, n, d, 2, 2)
return out.float()
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
del xq_, xk_
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
# --------------------------------------------------------------------------------
# Timestep Embedding
# --------------------------------------------------------------------------------
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
del freqs
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
del args
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
# --------------------------------------------------------------------------------
# Positional EmbedND
# --------------------------------------------------------------------------------
class EmbedND(nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
del ids, n_axes
return emb.unsqueeze(1)
# --------------------------------------------------------------------------------
# MLPEmbedder
# --------------------------------------------------------------------------------
class MLPEmbedder(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x):
x = self.silu(self.in_layer(x))
return self.out_layer(x)
# --------------------------------------------------------------------------------
# RMSNorm / QKNorm
# --------------------------------------------------------------------------------
if hasattr(torch, 'rms_norm'):
functional_rms_norm = torch.rms_norm
else:
def functional_rms_norm(x, normalized_shape, weight, eps):
if x.dtype in [torch.bfloat16, torch.float32]:
n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
else:
n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
return x * n
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = None
self.scale = nn.Parameter(torch.ones(dim))
self.eps = 1e-6
self.normalized_shape = [dim]
def forward(self, x):
if self.scale.dtype != x.dtype:
self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
class QKNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q, k, v):
del v
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(k), k.to(q)
# --------------------------------------------------------------------------------
# Self-Attention
# --------------------------------------------------------------------------------
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x, pe):
qkv = self.qkv(x)
B, L, _ = qkv.shape
qkv = qkv.view(B, L, 3, self.num_heads, -1)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
del q, k, v
x = self.proj(x)
return x
# --------------------------------------------------------------------------------
# Modulation + Blocks
# --------------------------------------------------------------------------------
class Modulation(nn.Module):
def __init__(self, dim, double):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec):
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return out
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img, txt, vec, pe):
# Image
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
del img_mod1_shift, img_mod1_scale
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
B, L, _ = img_qkv.shape
H = self.num_heads
D = img_qkv.shape[-1] // (3 * H)
img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# Text
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
del vec
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
del txt_mod1_shift, txt_mod1_scale
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
B, L, _ = txt_qkv.shape
txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# Merge
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
del txt_q, img_q, txt_k, img_k, txt_v, img_v
attn = attention(q, k, v, pe=pe)
del pe, q, k, v
txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
del attn
# Combine Image
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
del img_attn, img_mod1_gate
img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * self.img_norm2(img) + img_mod2_shift)
del img_mod2_gate, img_mod2_scale, img_mod2_shift
# Combine Text
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
del txt_attn, txt_mod1_gate
txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift
# Safety fix for half-precision
txt = fp16_fix(txt)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x, vec, pe):
mod_shift, mod_scale, mod_gate = self.modulation(vec)
del vec
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
del mod_shift, mod_scale
qkv, mlp = torch.split(
self.linear1(x_mod),
[3 * self.hidden_dim, self.mlp_hidden_dim],
dim=-1
)
del x_mod
qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_dim // self.num_heads)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe=pe)
del q, k, v, pe
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
del attn, mlp
x = x + mod_gate * output
del mod_gate, output
x = fp16_fix(x)
return x
# --------------------------------------------------------------------------------
# Last Layer
# --------------------------------------------------------------------------------
class LastLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, vec):
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
del vec
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
del scale, shift
x = self.linear(x)
return x
# --------------------------------------------------------------------------------
# Final IntegratedFluxTransformer2DModel
# --------------------------------------------------------------------------------
class IntegratedFluxTransformer2DModel(nn.Module):
def __init__(
self,
in_channels: int,
vec_in_dim: int,
context_in_dim: int,
hidden_size: int,
mlp_ratio: float,
num_heads: int,
depth: int,
depth_single_blocks: int,
axes_dim: list[int],
theta: int,
qkv_bias: bool,
guidance_embed: bool
):
super().__init__()
self.guidance_embed = guidance_embed
self.in_channels = in_channels * 4
self.out_channels = self.in_channels
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if guidance_embed
else nn.Identity()
)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList([
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
)
for _ in range(depth)
])
self.single_blocks = nn.ModuleList([
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
for _ in range(depth_single_blocks)
])
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def inner_forward(
self,
img,
img_ids,
txt,
txt_ids,
timesteps,
y,
guidance=None,
guidance_method="inverted_cosine"
):
"""
The main forward pass that merges image and text embeddings with
optional guidance embedding for classifier-free guidance or related tasks.
guidance_method: str
One of:
"cosine", "inverted_cosine", "sin", "linear_increase",
"linear_decrease", "random_noise", "random_gaussian", "random_extreme"
"""
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# Image patches => embed
img = self.img_in(img)
# Timestep => embed
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
# Guidance (if enabled)
if self.guidance_embed:
if guidance is None:
raise ValueError("Missing 'guidance' for guidance-distilled model.")
# Translate guidance with a chosen method, using the same normalized 'timesteps'
method_guidance = translate_guidance(timesteps, guidance, img.device, guidance_method)
vec = vec + self.guidance_in(timestep_embedding(method_guidance, 256).to(img.dtype))
# Additional vector conditioning
vec = vec + self.vector_in(y)
# Text => embed
txt = self.txt_in(txt)
del y, guidance
# Positional IDs => embed
ids = torch.cat((txt_ids, img_ids), dim=1)
del txt_ids, img_ids
pe = self.pe_embedder(ids)
del ids
# Double-Stream blocks (joint image/text attention)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
# Merge text+image, Single-Stream blocks
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
del pe
# Separate out the image portion
img = img[:, txt.shape[1]:, ...]
del txt
# Final output
img = self.final_layer(img, vec)
del vec
return img
def forward(self, x, timestep, context, y, guidance=None, **kwargs):
bs, c, h, w = x.shape
input_device = x.device
input_dtype = x.dtype
# Patchify the input
patch_size = 2
pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
img = rearrange(x, "b c (hh ph) (ww pw) -> b (hh ww) (c ph pw)",
ph=patch_size, pw=patch_size)
del x, pad_h, pad_w
# Dimensions after patchification
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
# Build image IDs
img_ids = torch.zeros((h_len, w_len, 3),
device=input_device, dtype=input_dtype)
# The 'height' index
img_ids[..., 1] = torch.linspace(
0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype
)[:, None]
# The 'width' index
img_ids[..., 2] = torch.linspace(
0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype
)[None, :]
img_ids = repeat(img_ids, "hh ww c -> b (hh ww) c", b=bs)
# Build text IDs (dummy: all zeros except possible batch dimension)
txt_ids = torch.zeros((bs, context.shape[1], 3),
device=input_device, dtype=input_dtype)
del input_device, input_dtype
# Pass to inner_forward
# We can specify guidance_method in kwargs if we want something else
out = self.inner_forward(
img,
img_ids,
context,
txt_ids,
timesteps=timestep,
y=y,
guidance=guidance,
# guidance_method=kwargs.get("guidance_method", "inverted_cosine")
)
del img, img_ids, txt_ids, timestep, context
# Un-patchify the output
out = rearrange(
out,
"b (hh ww) (c ph pw) -> b c (hh ph) (ww pw)",
hh=h_len, ww=w_len, ph=2, pw=2
)[:, :, :h, :w]
del h_len, w_len, bs
return out
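If you want to sanity-check the numbers before touching a generation, here is a minimal standalone sketch (assuming translate_guidance from the file above is pasted or imported alongside it) that sweeps the normalized progress and prints the effective guidance for the inverted-cosine method, starting from a distilled CFG of 3.5:

import torch

# Sweep normalized progress t from 0 (start) to 1 (end) and show how a
# distilled CFG of 3.5 gets reshaped by the inverted_cosine method above.
# Note: translate_guidance also prints its own debug lines for each call.
guidance = torch.tensor([3.5])
for step in range(11):
    t = torch.tensor(step / 10.0)
    out = translate_guidance(t, guidance, torch.device("cpu"), method="inverted_cosine")
    print(f"t={step / 10.0:.1f} -> effective guidance {out.item():.3f}")

With the current 5.0 shift, a 3.5 distilled CFG starts the run at exactly 3.5 and lands at 5.0 on the final step, with the cosine term shaping everything in between.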
Experiment 1: direct control.
Thanks to a combined effort and a thought experiment with Felldude, we have made a potential discovery here.
First, we have diffusion_engine/flux.py:
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management
class Flux(ForgeDiffusionEngine):
matched_guesses = [model_list.Flux, model_list.FluxSchnell]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
self.is_inpaint = False
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder'],
't5xxl': huggingface_components['text_encoder_2']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer'],
't5xxl': huggingface_components['tokenizer_2']
}
)
vae = VAE(model=huggingface_components['vae'])
if 'schnell' in estimated_config.huggingface_repo.lower():
k_predictor = PredictionFlux(
mu=1.0
)
else:
k_predictor = PredictionFlux(
seq_len=4096,
base_seq_len=256,
max_seq_len=4096,
base_shift=0.5,
max_shift=1.15,
)
self.use_distilled_cfg_scale = True
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler=None,
k_predictor=k_predictor,
config=estimated_config
)
self.text_processing_engine_l = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=False,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=True,
)
self.text_processing_engine_t5 = T5TextProcessingEngine(
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_l, pooled_l = self.text_processing_engine_l(prompt)
cond_t5 = self.text_processing_engine_t5(prompt)
cond = dict(crossattn=cond_t5, vector=pooled_l)
if self.use_distilled_cfg_scale:
distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))
print(f'Distilled CFG Scale: {distilled_cfg_scale}')
else:
print('Distilled CFG Scale will be ignored for Schnell')
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
return token_count, max(255, token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
I tinkered with this for a bit, but noticed that Forge only passes one distilled CFG value per generation, so I couldn't do what I wanted with it.
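For reference, this is the spot in the listing above where that single value gets set; it's computed once from the UI slider when the conditioning is built, so nothing timestep-dependent ever reaches the transformer from this layer (excerpt from get_learned_conditioning above, with an added comment):

if self.use_distilled_cfg_scale:
    # One static value per prompt, taken straight from the UI slider;
    # this is all the transformer ever receives as 'guidance'.
    distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
    cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))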
Experiment 2: Deeper I go
Then I went deeper.
Down the rabbit hole I went, until I found the guts of Flux itself.
# Single File Implementation of Flux with aggressive optimizations, Copyright Forge 2024
# If used outside Forge, only non-commercial use is allowed.
# See also https://github.com/black-forest-labs/flux
import math
import torch
from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter
def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe)
x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
return x
def rope(pos, dim, theta):
if pos.device.type == "mps" or pos.device.type == "xpu":
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
else:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
# out = torch.einsum("...n,d->...nd", pos, omega)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
del cos_out, sin_out
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
b, n, d, _ = out.shape
out = out.view(b, n, d, 2, 2)
return out.float()
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
del xq_, xk_
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
t = time_factor * t
half = dim // 2
# TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer
# Do not block CUDA steam, but having about 1e-4 differences with Flux official codes:
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
# Block CUDA steam, but consistent with official codes:
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
del freqs
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
del args
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class EmbedND(nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
del ids, n_axes
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x):
x = self.silu(self.in_layer(x))
return self.out_layer(x)
if hasattr(torch, 'rms_norm'):
functional_rms_norm = torch.rms_norm
else:
def functional_rms_norm(x, normalized_shape, weight, eps):
if x.dtype in [torch.bfloat16, torch.float32]:
n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
else:
n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
return x * n
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = None # to trigger module_profile
self.scale = nn.Parameter(torch.ones(dim))
self.eps = 1e-6
self.normalized_shape = [dim]
def forward(self, x):
if self.scale.dtype != x.dtype:
self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)
class QKNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q, k, v):
del v
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(k), k.to(q)
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x, pe):
qkv = self.qkv(x)
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
B, L, _ = qkv.shape
qkv = qkv.view(B, L, 3, self.num_heads, -1)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
del q, k, v
x = self.proj(x)
return x
class Modulation(nn.Module):
def __init__(self, dim, double):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec):
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return out
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img, txt, vec, pe):
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
del img_mod1_shift, img_mod1_scale
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
# img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
B, L, _ = img_qkv.shape
H = self.num_heads
D = img_qkv.shape[-1] // (3 * H)
img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
del vec
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
del txt_mod1_shift, txt_mod1_scale
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
B, L, _ = txt_qkv.shape
txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
attn = attention(q, k, v, pe=pe)
del pe, q, k, v
txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
del attn
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
del img_attn, img_mod1_gate
img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * self.img_norm2(img) + img_mod2_shift)
del img_mod2_gate, img_mod2_scale, img_mod2_shift
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
del txt_attn, txt_mod1_gate
txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift
txt = fp16_fix(txt)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x, vec, pe):
mod_shift, mod_scale, mod_gate = self.modulation(vec)
del vec
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
del mod_shift, mod_scale
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
del x_mod
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe=pe)
del q, k, v, pe
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
del attn, mlp
x = x + mod_gate * output
del mod_gate, output
x = fp16_fix(x)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, vec):
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
del vec
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
del scale, shift
x = self.linear(x)
return x
class IntegratedFluxTransformer2DModel(nn.Module):
def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool):
super().__init__()
self.guidance_embed = guidance_embed
self.in_channels = in_channels * 4
self.out_channels = self.in_channels
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else nn.Identity()
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
            # Generate random noise within the range [-0.75, 0.75]
random_noise = torch.empty(1).uniform_(-0.75, 0.75).to(img.device)
# Adjust guidance strength with random noise
adjusted_guidance = guidance + random_noise
# Embed the adjusted guidance into the timestep
vec = vec + self.guidance_in(timestep_embedding(adjusted_guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
del y, guidance
ids = torch.cat((txt_ids, img_ids), dim=1)
del txt_ids, img_ids
pe = self.pe_embedder(ids)
del ids
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
del pe
img = img[:, txt.shape[1]:, ...]
del txt
img = self.final_layer(img, vec)
del vec
return img
def forward(self, x, timestep, context, y, guidance=None, **kwargs):
bs, c, h, w = x.shape
input_device = x.device
input_dtype = x.dtype
patch_size = 2
pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
del x, pad_h, pad_w
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
del input_device, input_dtype
out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
del img, img_ids, txt_ids, timestep, context
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
del h_len, w_len, bs
return out
I noticed a few core things here:
There is a comment in the timestep embedding code.
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
t = time_factor * t
half = dim // 2
# TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer
There is no config directly trickled down to this point.
Here you'll see my badly optimized contribution for this experiment:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
# Generate random noise within the range [-0.75, 0.75]
random_noise = torch.empty(1).uniform_(-0.75, 0.75).to(img.device)
# Adjust guidance strength with random noise
adjusted_guidance = guidance + random_noise
# Embed the adjusted guidance into the timestep
vec = vec + self.guidance_in(timestep_embedding(adjusted_guidance, 256).to(img.dtype))
Alright, so my hypothesis was simple. If we can control the distillation guidance at every step, we have a toggle over a few factors, primarily speed and quality. A nearly completed image does not necessarily need full distillation if it's a flat 2D image or an anime character, just as a 3D-style image does not gain nearly as much from it early in the schedule.
So I tried a few experiments: a bell curve, matching the cosine math, linear decrease, linear increase, inverted cosine, random noise, and a few others.
The outcomes were... very interesting each time.
Higher speed when the CFG reduces later, distorted hands with random noise, additional details when the CFG is first reduced and then increased at the end. The code posted in the excerpt above is the random-noise variant.
The most interesting results came from going from less distilled CFG to more, with some reduction near the end, but not too much. That schedule also reduces the likelihood of frying the image, burning the edges, or destroying the coloration.
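For anyone who wants to try that shape directly, here is a minimal sketch of one way to express it as an extra branch for translate_guidance. The "ramp_then_taper" name and the 1.5 / 0.85 / 0.3 constants are my own placeholders rather than anything pulled from Forge, so treat them as starting points to tune:

import torch

def ramp_then_taper(t, g, ramp_amount=1.5, taper_start=0.85, taper_amount=0.3):
    """Ramp the distilled guidance up over most of the run, then pull it back
    slightly near the end. 't' is normalized progress in [0, 1] as a tensor,
    'g' is the original guidance tensor."""
    # Linear ramp from g toward g + ramp_amount across the run
    ramped = g + ramp_amount * t
    # After taper_start, subtract an amount that grows toward the final step
    taper = torch.clamp((t - taper_start) / (1.0 - taper_start), min=0.0) * taper_amount
    return ramped - taper

Dropped into translate_guidance as elif method == "ramp_then_taper": transformed = ramp_then_taper(t, g), it gets blended with the original guidance by the same (1 - t)*g + t*transformed interpolation as the other methods.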
Training ResNets for rapid difference classification at inference time.
My idea was to introduce and train block-by-block ResNet-50s (or smaller) to capture the remnant rounding differences while comparing Schnell values to Dev, but that didn't pan out. The numbers are drastically different, and running both models at once took an H100 that I didn't want to keep running for very long. The classifiers would have output a value between 0 and 10, with 0 meaning deviance below a certain threshold and 10 meaning deviance above it.
The captured data did not pan out, however; it's too drastically different. So I moved back to building a bit of a proof-of-concept tool rather than going into full new-model-training mode.
On top of that, reading the Forge code implies that they REALLY don't like passing configs downward into the actual model logic the way I did... but I did it anyway... So this is most definitely a topical experiment for now, but the code will definitely work if you place it into Forge manually until I prepare the tool.
A new tool to play with later today.
I've replicated quasi-similar scheduler math to the distillation CFG implied by Flux1D and Flux1dPro2, and I will be pushing my timestep-based distillation CFG tool to a GitHub repo meant for ComfyUI later today, along with carefully formatted changes for Forge. These two models are nearly identical configuration-wise, yet the outcomes from both are so drastically different that it has made me question everything related to distillation.
https://huggingface.co/multimodalart/FLUX.1-dev2pro-full/blob/main/scheduler/scheduler_config.json
This distillation configuration concept started as a bit of a toy, something that could have been a side product, or a byproduct of remnants of something internal to the Flux architecture itself.
THE BIGGEST TAKEAWAY HERE.
This is CLEARLY meant to be controlled in a DYNAMIC fashion, step by step or with timestep-specific values, and NOT in a STATIC fashion, which is what these inference engines imply by giving us a single slider that feeds a single internal multiplier.
There are many, MANY steps in some image requests that, quite frankly, finish early. If you reduce the distilled CFG early, you end up with a better outcome overall in less time.
I'll be making a more thorough write-up with comparison photos, and pushing a working distilled CFG shifter into Forge and ComfyUI for testing and playing with.