
Success. Semi-DeDistilling all Flux models at once; Flux1D Distillation CFG Forge Experiment


Update: IT WORKS!!!

Inverted cosine is a huge winner. It provides a massive boost in fidelity and accuracy when using Simulacrum V4, and introduces far more information than I could have expected.

The math listed here is similar to the actual math used when training the model, and it seems to have activated a series of pathways that have gone under the radar, a kind of bypass that is normally overridden by the stronger UNET sections.

It seems to have enabled my CLIP_L to truly shine in a way that is very unexpected and very welcome.

It has enabled behavior similar to core Flux, while retaining many of the SFW and NSFW elements that I trained into Simulacrum's core and the subsequent follow-up trainings!

This has made me very excited!

In the process, it seems to have enabled consistent CFG uniformity as well. What a treat of a byproduct.

Both CFG and Distilled CFG have functional negative prompting with this.

Try it yourself. You will be shocked. If you can share the EXACT math for why this works, or potentially improve on it, that would be ideal.

Enabling standard CFG alongside the distilled CFG now produces much better results overall. Every model I've tried has behaved more like the de-distilled model.

Update: Here's some forge code to play with.

Replace the whole nn\flux.py file with this one.

Inside, you'll see a method named translate_guidance at the top. You can adjust its inputs and the various parameters below to play with.

You'll likely need to fully restart Forge after each change, since I didn't bother rigging it up to a notebook or a toggle somewhere else, but it's definitely working and a good showcase of the proof of concept.

These formulas can be expanded, changed, added, whatever. Just stick to the methodology.

Currently it's set to print a bunch of debug info, so you can watch the console to see exactly what is happening if you wish.

I've updated the code slightly down at the bottom, so nothing below "translate_guidance" needs touching if you paste the whole file; the code at the top can simply have its "method" parameter changed to reflect the desired guidance variation.

I've adjusted inverted cosine to better reflect a more consistent and accurate outcome value.

Based on my testing, inverted cosine is better served with a very high peak at the start and finish, with the middle left to a lower distilled CFG.
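
For a concrete sense of the numbers, take the inverted_cosine branch of the code below with a distilled CFG of 3.5: at t = 0.5 the transformed value is 5 - cos(pi*(3.5 + 0.5)) = 5 - 1 = 4, so the returned guidance is (1 - 0.5)*3.5 + 0.5*4 = 3.75; at t = 1.0 it is 5 - cos(pi*4.5) = 5 - 0 = 5. In other words, the effective guidance starts at the original value and climbs over the run.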

import math
import torch
from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter


def translate_guidance(timestep, guidance, device, method="inverted_cosine"):
    """
    Interpolate and transform guidance strength based on a normalized 'timestep' in [0,1].

    Arguments:
      timestep : scalar or 1D tensor in [0,1], representing the normalized progress.
        - 0 => beginning of inference
        - 1 => end of inference
      guidance : scalar or 1D tensor containing the original guidance value(s).
      device   : torch device on which computations will run.
      method   : Which transformation method to use on the guidance.

    The returned guidance = (1 - t)*original_guidance + t*transformed_guidance.
    The difference here is that 'transformed_guidance' itself may depend on `t`,
    so we incorporate `t` both in the transformation and in the final interpolation.

    Possible examples of 'method':
      1) "cosine"          : cos(pi*(g + t))
      2) "inverted_cosine" : 5.0 - cos(pi*(g + t))
      3) "sin"             : sin(pi*(g + t))
      4) "linear_increase" : g + t*1.5
      5) "linear_decrease" : g - t*1.5
      6) "random_noise"    : add uniform noise once, then keep it consistent
      7) "random_gaussian" : add normal(0, std=0.3) noise once
      8) "random_extreme"  : add uniform(-2, 2) noise once

    You can of course tailor these transformations to your training/inference scheme.
    """
    # Ensure timestep is [0,1].
    t = timestep.to(device, dtype=torch.float32).clamp_(0.0, 1.0)

    # Original guidance as float on the correct device
    g = guidance.to(device, dtype=torch.float32)

    # 1) Decide how the *transformed* guidance is computed, including dependence on t
    if method == "cosine":
        # Incorporate t so that the transform changes over time
        # e.g. cos(pi*(g + t)) yields a changing transformation
        transformed = torch.cos(math.pi * (g + t))

    elif method == "inverted_cosine":
        # Shift by 5.0, but also incorporate t so it evolves over time.
        # If you used '5.0 - cos(pi*g)' alone, it would remain constant for a constant g,
        # so we add t to ensure it shifts each step:
        transformed = 5.0 - torch.cos(math.pi * (g + t))

    elif method == "sin":
        # Use g + t inside the sine for changing transformations
        transformed = torch.sin(math.pi * (g + t))

    elif method == "linear_increase":
        # Let the transformation itself incorporate t. For example:
        transformed = g + (1.5 * t)

    elif method == "linear_decrease":
        transformed = g - (1.5 * t)

    elif method == "random_noise":
        # If you truly want random noise each call, apply it here
        # But be aware that if you keep calling it, you'll get new noise every time
        noise = torch.empty_like(g).uniform_(-0.75, 0.75)
        # Optionally incorporate t in the final scaled noise
        transformed = g + (noise * t)

    elif method == "random_gaussian":
        # Same idea, but Gaussian
        noise = torch.randn_like(g) * 0.3
        transformed = g + (noise * t)

    elif method == "random_extreme":
        noise = torch.empty_like(g).uniform_(-2.0, 2.0)
        transformed = g + (noise * t)

    else:
        # Fallback if method is invalid or not recognized
        transformed = g

    # 2) Interpolate between the original guidance (g) and the new transformed value
    #    The factor t in the interpolation means:
    #      - When t=0, you stick to the original g.
    #      - When t=1, you fully adopt 'transformed'.
    #    If the transform also uses t, you get a more pronounced shift over time.
    out = (1.0 - t) * g + (t * transformed)

    # Debug logging; comment these out once you're done experimenting:
    print("Original guidance:", g)
    print("Transformed guidance:", transformed)
    print("Linear progress (t):", t)
    print("Adjusted Transform:", t * transformed)
    print("Timestep Adjusted:", (1.0 - t) * g)
    print("Resulting guidance:", out)

    return out

# --------------------------------------------------------------------------------
# Core attention + RoPE
# --------------------------------------------------------------------------------

def attention(q, k, v, pe):
    q, k = apply_rope(q, k, pe)
    x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
    return x

def rope(pos, dim, theta):
    if pos.device.type == "mps" or pos.device.type == "xpu":
        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    else:
        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)

    out = pos.unsqueeze(-1) * omega.unsqueeze(0)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    del cos_out, sin_out

    b, n, d, _ = out.shape
    out = out.view(b, n, d, 2, 2)
    return out.float()

def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    del xq_, xk_
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


# --------------------------------------------------------------------------------
# Timestep Embedding
# --------------------------------------------------------------------------------

def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    del freqs
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    del args
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


# --------------------------------------------------------------------------------
# Positional EmbedND
# --------------------------------------------------------------------------------

class EmbedND(nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        del ids, n_axes
        return emb.unsqueeze(1)


# --------------------------------------------------------------------------------
# MLPEmbedder
# --------------------------------------------------------------------------------

class MLPEmbedder(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x = self.silu(self.in_layer(x))
        return self.out_layer(x)


# --------------------------------------------------------------------------------
# RMSNorm / QKNorm
# --------------------------------------------------------------------------------

if hasattr(torch, 'rms_norm'):
    functional_rms_norm = torch.rms_norm
else:
    def functional_rms_norm(x, normalized_shape, weight, eps):
        if x.dtype in [torch.bfloat16, torch.float32]:
            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
        else:
            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
        return x * n

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = None
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = 1e-6
        self.normalized_shape = [dim]

    def forward(self, x):
        if self.scale.dtype != x.dtype:
            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)

class QKNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q, k, v):
        del v
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(k), k.to(q)


# --------------------------------------------------------------------------------
# Self-Attention
# --------------------------------------------------------------------------------

class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, pe):
        qkv = self.qkv(x)
        B, L, _ = qkv.shape
        qkv = qkv.view(B, L, 3, self.num_heads, -1)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        del qkv

        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        del q, k, v
        x = self.proj(x)
        return x


# --------------------------------------------------------------------------------
# Modulation + Blocks
# --------------------------------------------------------------------------------

class Modulation(nn.Module):
    def __init__(self, dim, double):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec):
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        return out

class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img, txt, vec, pe):
        # Image
        img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec)
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
        del img_mod1_shift, img_mod1_scale

        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated
        B, L, _ = img_qkv.shape
        H = self.num_heads
        D = img_qkv.shape[-1] // (3 * H)
        img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        del img_qkv
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Text
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
        del vec
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
        del txt_mod1_shift, txt_mod1_scale

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        B, L, _ = txt_qkv.shape
        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Merge
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        del txt_q, img_q, txt_k, img_k, txt_v, img_v

        attn = attention(q, k, v, pe=pe)
        del pe, q, k, v
        txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
        del attn

        # Combine Image
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        del img_attn, img_mod1_gate
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * self.img_norm2(img) + img_mod2_shift)
        del img_mod2_gate, img_mod2_scale, img_mod2_shift

        # Combine Text
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        del txt_attn, txt_mod1_gate
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
        del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift

        # Safety fix for half-precision
        txt = fp16_fix(txt)
        return img, txt


class SingleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(head_dim)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x, vec, pe):
        mod_shift, mod_scale, mod_gate = self.modulation(vec)
        del vec

        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
        del mod_shift, mod_scale

        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_dim, self.mlp_hidden_dim],
            dim=-1
        )
        del x_mod

        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_dim // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        del qkv

        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)
        del q, k, v, pe

        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
        del attn, mlp

        x = x + mod_gate * output
        del mod_gate, output

        x = fp16_fix(x)
        return x


# --------------------------------------------------------------------------------
# Last Layer
# --------------------------------------------------------------------------------

class LastLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, vec):
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        del vec
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        del scale, shift
        x = self.linear(x)
        return x





# --------------------------------------------------------------------------------
# Final IntegratedFluxTransformer2DModel
# --------------------------------------------------------------------------------

class IntegratedFluxTransformer2DModel(nn.Module):
    def __init__(
        self,
        in_channels: int,
        vec_in_dim: int,
        context_in_dim: int,
        hidden_size: int,
        mlp_ratio: float,
        num_heads: int,
        depth: int,
        depth_single_blocks: int,
        axes_dim: list[int],
        theta: int,
        qkv_bias: bool,
        guidance_embed: bool
    ):
        super().__init__()

        self.guidance_embed = guidance_embed
        self.in_channels = in_channels * 4
        self.out_channels = self.in_channels

        if hidden_size % num_heads != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")

        pe_dim = hidden_size // num_heads
        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList([
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
            )
            for _ in range(depth)
        ])

        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth_single_blocks)
        ])

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def inner_forward(
        self,
        img,
        img_ids,
        txt,
        txt_ids,
        timesteps,
        y,
        guidance=None,
        guidance_method="inverted_cosine"
    ):
        """
        The main forward pass that merges image and text embeddings with
        optional guidance embedding for classifier-free guidance or related tasks.

        guidance_method: str
            One of:
                "cosine", "inverted_cosine", "sin", "linear_increase",
                "linear_decrease", "random_noise", "random_gaussian", "random_extreme"
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Image patches => embed
        img = self.img_in(img)

        # Timestep => embed
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))

        # Guidance (if enabled)
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Missing 'guidance' for guidance-distilled model.")
            # Translate guidance with a chosen method, using the same normalized 'timesteps'
            method_guidance = translate_guidance(timesteps, guidance, img.device, guidance_method)
            vec = vec + self.guidance_in(timestep_embedding(method_guidance, 256).to(img.dtype))

        # Additional vector conditioning
        vec = vec + self.vector_in(y)

        # Text => embed
        txt = self.txt_in(txt)
        del y, guidance

        # Positional IDs => embed
        ids = torch.cat((txt_ids, img_ids), dim=1)
        del txt_ids, img_ids
        pe = self.pe_embedder(ids)
        del ids

        # Double-Stream blocks (joint image/text attention)
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Merge text+image, Single-Stream blocks
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        del pe

        # Separate out the image portion
        img = img[:, txt.shape[1]:, ...]
        del txt

        # Final output
        img = self.final_layer(img, vec)
        del vec
        return img

    def forward(self, x, timestep, context, y, guidance=None, **kwargs):
        bs, c, h, w = x.shape
        input_device = x.device
        input_dtype = x.dtype

        # Patchify the input
        patch_size = 2
        pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
        pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")

        img = rearrange(x, "b c (hh ph) (ww pw) -> b (hh ww) (c ph pw)",
                        ph=patch_size, pw=patch_size)
        del x, pad_h, pad_w

        # Dimensions after patchification
        h_len = (h + (patch_size // 2)) // patch_size
        w_len = (w + (patch_size // 2)) // patch_size

        # Build image IDs
        img_ids = torch.zeros((h_len, w_len, 3),
                              device=input_device, dtype=input_dtype)
        # The 'height' index
        img_ids[..., 1] = torch.linspace(
            0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype
        )[:, None]
        # The 'width' index
        img_ids[..., 2] = torch.linspace(
            0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype
        )[None, :]

        img_ids = repeat(img_ids, "hh ww c -> b (hh ww) c", b=bs)

        # Build text IDs (dummy: all zeros except possible batch dimension)
        txt_ids = torch.zeros((bs, context.shape[1], 3),
                              device=input_device, dtype=input_dtype)
        del input_device, input_dtype

        # Pass to inner_forward
        # We can specify guidance_method in kwargs if we want something else
        out = self.inner_forward(
            img,
            img_ids,
            context,
            txt_ids,
            timesteps=timestep,
            y=y,
            guidance=guidance,
            # guidance_method=kwargs.get("guidance_method", "inverted_cosine")
        )
        del img, img_ids, txt_ids, timestep, context

        # Un-patchify the output
        out = rearrange(
            out,
            "b (hh ww) (c ph pw) -> b c (hh ph) (ww pw)",
            hh=h_len, ww=w_len, ph=2, pw=2
        )[:, :, :h, :w]
        del h_len, w_len, bs
        return out
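
If you want a quick sanity check of what each curve does to the guidance value before running a full generation, here's a small standalone sketch. The import path is an assumption (point it at wherever you pasted the file), and the function will also print its own debug lines:

import torch

# Assumed import path; adjust to match your Forge install.
from backend.nn.flux import translate_guidance

g = torch.tensor([3.5])  # a typical distilled CFG value
for method in ("cosine", "inverted_cosine", "linear_decrease"):
    values = [
        translate_guidance(torch.tensor([t]), g, "cpu", method).item()
        for t in (0.0, 0.25, 0.5, 0.75, 1.0)
    ]
    print(method, [round(v, 3) for v in values])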

Experiment 1: direct control.

Thanks to a combined effort and a thought experiment I had with Felldude, we may have made a discovery here.

First, we have diffusion_engine/flux.py:

import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management


class Flux(ForgeDiffusionEngine):
    matched_guesses = [model_list.Flux, model_list.FluxSchnell]

    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        clip = CLIP(
            model_dict={
                'clip_l': huggingface_components['text_encoder'],
                't5xxl': huggingface_components['text_encoder_2']
            },
            tokenizer_dict={
                'clip_l': huggingface_components['tokenizer'],
                't5xxl': huggingface_components['tokenizer_2']
            }
        )

        vae = VAE(model=huggingface_components['vae'])

        if 'schnell' in estimated_config.huggingface_repo.lower():
            k_predictor = PredictionFlux(
                mu=1.0
            )
        else:
            k_predictor = PredictionFlux(
                seq_len=4096,
                base_seq_len=256,
                max_seq_len=4096,
                base_shift=0.5,
                max_shift=1.15,
            )
            self.use_distilled_cfg_scale = True

        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=k_predictor,
            config=estimated_config
        )

        self.text_processing_engine_l = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_l,
            tokenizer=clip.tokenizer.clip_l,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_l',
            embedding_expected_shape=768,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=False,
            minimal_clip_skip=1,
            clip_skip=1,
            return_pooled=True,
            final_layer_norm=True,
        )

        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=clip.cond_stage_model.t5xxl,
            tokenizer=clip.tokenizer.t5xxl,
            emphasis_name=dynamic_args['emphasis_name'],
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine_l.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        cond_l, pooled_l = self.text_processing_engine_l(prompt)
        cond_t5 = self.text_processing_engine_t5(prompt)
        cond = dict(crossattn=cond_t5, vector=pooled_l)

        if self.use_distilled_cfg_scale:
            distilled_cfg_scale = getattr(prompt, 'distilled_cfg_scale', 3.5) or 3.5
            cond['guidance'] = torch.FloatTensor([distilled_cfg_scale] * len(prompt))
            print(f'Distilled CFG Scale: {distilled_cfg_scale}')
        else:
            print('Distilled CFG Scale will be ignored for Schnell')

        return cond

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
        return token_count, max(255, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)

I tinkered with this for a bit, but noticed that Forge only passes one distilled CFG value per generation, so I couldn't do what I wanted with it.

Experiment 2: Deeper I go

Then I went deeper.

Down the rabbit hole I went, until I found the guts of Flux itself.

# Single File Implementation of Flux with aggressive optimizations, Copyright Forge 2024
# If used outside Forge, only non-commercial use is allowed.
# See also https://github.com/black-forest-labs/flux


import math
import torch

from torch import nn
from einops import rearrange, repeat
from backend.attention import attention_function
from backend.utils import fp16_fix, tensor2parameter


def attention(q, k, v, pe):
    q, k = apply_rope(q, k, pe)
    x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
    return x


def rope(pos, dim, theta):
    if pos.device.type == "mps" or pos.device.type == "xpu":
        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    else:
        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)

    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = pos.unsqueeze(-1) * omega.unsqueeze(0)

    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    del cos_out, sin_out

    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    b, n, d, _ = out.shape
    out = out.view(b, n, d, 2, 2)

    return out.float()


def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    del xq_, xk_
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    t = time_factor * t
    half = dim // 2

    # TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer

    # Do not block CUDA steam, but having about 1e-4 differences with Flux official codes:
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

    # Block CUDA steam, but consistent with official codes:
    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    del freqs
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    del args
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class EmbedND(nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        del ids, n_axes
        return emb.unsqueeze(1)


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x = self.silu(self.in_layer(x))
        return self.out_layer(x)


if hasattr(torch, 'rms_norm'):
    functional_rms_norm = torch.rms_norm
else:
    def functional_rms_norm(x, normalized_shape, weight, eps):
        if x.dtype in [torch.bfloat16, torch.float32]:
            n = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps) * weight
        else:
            n = torch.rsqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps).to(x.dtype) * weight
        return x * n


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = None  # to trigger module_profile
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = 1e-6
        self.normalized_shape = [dim]

    def forward(self, x):
        if self.scale.dtype != x.dtype:
            self.scale = tensor2parameter(self.scale.to(dtype=x.dtype))
        return functional_rms_norm(x, self.normalized_shape, self.scale, self.eps)


class QKNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q, k, v):
        del v
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(k), k.to(q)


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, pe):
        qkv = self.qkv(x)

        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = qkv.shape
        qkv = qkv.view(B, L, 3, self.num_heads, -1)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        del qkv

        q, k = self.norm(q, k, v)

        x = attention(q, k, v, pe=pe)
        del q, k, v

        x = self.proj(x)
        return x


class Modulation(nn.Module):
    def __init__(self, dim, double):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec):
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        return out


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img, txt, vec, pe):
        img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec)

        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
        del img_mod1_shift, img_mod1_scale
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated

        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = img_qkv.shape
        H = self.num_heads
        D = img_qkv.shape[-1] // (3 * H)
        img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        del img_qkv

        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
        del vec

        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
        del txt_mod1_shift, txt_mod1_scale
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated

        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = txt_qkv.shape
        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        del txt_qkv

        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        del txt_q, img_q
        k = torch.cat((txt_k, img_k), dim=2)
        del txt_k, img_k
        v = torch.cat((txt_v, img_v), dim=2)
        del txt_v, img_v

        attn = attention(q, k, v, pe=pe)
        del pe, q, k, v
        txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
        del attn

        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        del img_attn, img_mod1_gate
        img = img + img_mod2_gate * self.img_mlp((1 + img_mod2_scale) * self.img_norm2(img) + img_mod2_shift)
        del img_mod2_gate, img_mod2_scale, img_mod2_shift

        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        del txt_attn, txt_mod1_gate
        txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
        del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift

        txt = fp16_fix(txt)

        return img, txt


class SingleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(head_dim)
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x, vec, pe):
        mod_shift, mod_scale, mod_gate = self.modulation(vec)
        del vec
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
        del mod_shift, mod_scale
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        del x_mod

        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        del qkv

        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)
        del q, k, v, pe
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
        del attn, mlp

        x = x + mod_gate * output
        del mod_gate, output

        x = fp16_fix(x)

        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, vec):
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        del vec
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        del scale, shift
        x = self.linear(x)
        return x


class IntegratedFluxTransformer2DModel(nn.Module):
    def __init__(self, in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool):
        super().__init__()

        self.guidance_embed = guidance_embed
        self.in_channels = in_channels * 4
        self.out_channels = self.in_channels

        if hidden_size % num_heads != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")

        pe_dim = hidden_size // num_heads
        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                )
                for _ in range(depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            # Generate random noise within the range [-0.75, 0.75]
            random_noise = torch.empty(1).uniform_(-0.75, 0.75).to(img.device)

            # Adjust guidance strength with random noise
            adjusted_guidance = guidance + random_noise

            # Embed the adjusted guidance into the timestep
            vec = vec + self.guidance_in(timestep_embedding(adjusted_guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)
        del y, guidance
        ids = torch.cat((txt_ids, img_ids), dim=1)
        del txt_ids, img_ids
        pe = self.pe_embedder(ids)
        del ids
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        del pe
        img = img[:, txt.shape[1]:, ...]
        del txt
        img = self.final_layer(img, vec)
        del vec
        return img

    def forward(self, x, timestep, context, y, guidance=None, **kwargs):
        bs, c, h, w = x.shape
        input_device = x.device
        input_dtype = x.dtype
        patch_size = 2
        pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
        pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        del x, pad_h, pad_w
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
        del input_device, input_dtype
        out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
        del img, img_ids, txt_ids, timestep, context
        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
        del h_len, w_len, bs
        return out

I noticed a few core things here:

  1. There is a comment in the timestep embedding code.

def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    t = time_factor * t
    half = dim // 2

    # TODO: Once A trainer for flux get popular, make timestep_embedding consistent to that trainer

  2. There is no config directly trickled to this point.

Here you'll see my badly optimized contribution meant for this experiment.

            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            # Generate random noise within the range [-0.75, 0.75]
            random_noise = torch.empty(1).uniform_(-0.75, 0.75).to(img.device)

            # Adjust guidance strength with random noise
            adjusted_guidance = guidance + random_noise

            # Embed the adjusted guidance into the timestep
            vec = vec + self.guidance_in(timestep_embedding(adjusted_guidance, 256).to(img.dtype))

Alright, so my hypothesis was simple: if we can control the distillation at every step, we have a toggle over a few factors, primarily speed and quality. A nearly completed image does not necessarily need full distillation if it's a flat 2D image or an anime character, just as a 3D image doesn't gain nearly as much early in the distillation configuration.

So I tried a few experiments: a bell curve, matching the cosine math, linear reduction, linear increase, inverted cosine, random noise, and a few others.

The outcomes were... very interesting each time.

Higher speed when the CFG is reduced later in the run, distorted hands with random noise, and additional detail when the CFG starts reduced and is then increased at the end. The code posted in this experiment uses random noise.

The most interesting results came from going from less distilled CFG to more, with some reduction near the end, but not too much. That also reduces the likelihood of frying the image, burning the edges, or destroying the coloration.
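
For reference, here is a minimal sketch of that kind of ramp-up-then-taper schedule as a pure function of normalized progress t. This is just an illustration of the shape, not the exact math from the file above:

import math

def ramp_then_taper(t: float, base: float = 3.5, low: float = 2.5, high: float = 4.5) -> float:
    # Start below the base distilled CFG, ramp up smoothly through the middle,
    # then ease back down slightly near the end (illustrative shape only).
    if t < 0.8:
        progress = t / 0.8
        return low + (high - low) * 0.5 * (1 - math.cos(math.pi * progress))
    progress = (t - 0.8) / 0.2
    return high - (high - base) * progress

# Sample the curve at a few points in the run:
print([round(ramp_then_taper(t), 2) for t in (0.0, 0.25, 0.5, 0.75, 0.9, 1.0)])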

Training ResNets for rapid difference classification at inference time.

My idea was to introduce and train block-by-block ResNet-50s or smaller to capture the remnant roundings while comparing Schnell values to Dev, but that didn't pan out. The numbers are drastically different, and running both models took an H100 that I didn't want to keep running for very long. The classifiers would have output a value between 0 and 10, with 0 being deviance below a certain threshold and 10 being deviance above it.

The captured data did not pan out, however; it's too drastically different. So I moved back to creating a bit of a proof-of-concept tool rather than going full new-model training mode.

On top of that, reading the Forge code implies that they REALLY don't like passing configs downward into the actual model logic the way I did... but I did it anywayyyyyyy... So this is most definitely a temporary experiment for now, but this code will definitely work if you place it into Forge manually until I prepare the tool.

A new tool to play with later today.

I've replicated quasi-similar scheduler math to the distillation CFG implied by Flux1D and Flux1dPro2, and will be pushing my timestep-based distillation CFG tool to a GitHub repo meant for ComfyUI later today, along with carefully formatted changes to Forge. These two models are nearly identical configuration-wise, yet the outcomes from both are so drastically different that it's made me question everything related to distillation.

https://huggingface.co/multimodalart/FLUX.1-dev2pro-full/blob/main/scheduler/scheduler_config.json

This distillation configuration concept started as a bit of a toy, something that could have been a side product or a byproduct of remnants of something internal to the Flux architecture itself.

THE BIGGEST TAKEAWAY HERE.

This is CLEARLY meant to be controlled in a DYNAMIC FASHION, step by step or with timestep-allocated values, and NOT in a STATIC FASHION as implied by these inference engines, which give us a single slider driving a single internal multiplier.

There are many, MANY steps in some image requests that, quite frankly, finish early. If you reduce the distilled CFG early, you end up with a better outcome overall in less time.
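
To make the dynamic-versus-static point concrete, a per-step schedule is just a list of guidance values instead of one slider value. This is a hypothetical, engine-agnostic sketch, not Forge's or ComfyUI's actual API:

def build_guidance_schedule(num_steps, static_value=3.5, late_value=2.0, cutover=0.6):
    # Static behavior would be the same value every step; here we drop the
    # distilled CFG once the image is mostly formed (past `cutover` progress).
    schedule = []
    for i in range(num_steps):
        t = i / max(num_steps - 1, 1)  # normalized progress 0..1
        schedule.append(static_value if t < cutover else late_value)
    return schedule

# A hypothetical sampling loop would then pass schedule[i] per step instead of a constant:
# for i, sigma in enumerate(sigmas):
#     guidance = torch.FloatTensor([schedule[i]])
#     ... model(x, timestep, context, y, guidance=guidance) ...
print(build_guidance_schedule(10))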

I'll be making a more thorough write-up with comparison photos, and pushing a working distilled CFG shifter into Forge and ComfyUI for testing and playing with.
