FFT loss implementation [SDXL] [sd_scripts]

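This article describes how to use the fractional Fourier transform (FrFT) as a loss function when training diffusion models. At a transform order of 1.0 the FrFT corresponds to the ordinary Fourier transform, which is why the loss type is named "fft" below.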

What you need:

  1. Install https://github.com/tunakasif/torch-frft

  2. Integrate the formula into your training process (the example below uses sd_scripts)

  3. Fix any errors that come up in your particular setup.
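
Before editing the training code, it may help to confirm that the package from step 1 is importable in the environment used for training. The sketch below assumes the import name is `torch_frft`, as in the import lines used later in this article; the helper name `has_module` is made up for illustration.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # "torch_frft" is the import name assumed from the imports used in this article.
    if has_module("torch_frft"):
        print("torch_frft is available")
    else:
        print("torch_frft is missing; install it from the repository linked above")
```

Run it with the same Python interpreter that launches training, since virtual environments often differ.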

Steps

1. Add the imports and a complex-valued MSE helper to library/train_util.py:

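# torch is already imported at the top of train_util.py
from torch_frft.frft_module import frft
from torch_frft.dfrft_module import dfrft, dfrftmtx

def mse_complex(x, y):
    # Mean squared magnitude of the complex difference: mean(|x - y|^2)
    diff = x - y
    return torch.mean(diff.real ** 2 + diff.imag ** 2)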
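
To make clear what `mse_complex` measures, here is a minimal sketch that compares it with the squared magnitude computed via `torch.abs`. It uses `torch.fft.fft` with `norm="ortho"` as a stand-in transform; this is an assumption made so the example runs without torch-frft, not the transform used in this article. For such an orthonormal transform the value also matches the plain pixel-space MSE (Parseval's theorem); whether `dfrft` at order 1.0 behaves identically is not verified here.

```python
import torch


def mse_complex(x, y):
    # Mean squared magnitude of the complex difference: mean(|x - y|^2)
    diff = x - y
    return torch.mean(diff.real ** 2 + diff.imag ** 2)


torch.manual_seed(0)
pred = torch.randn(2, 4, 8, 8)    # dummy "model_pred" in latent shape (B, C, H, W)
target = torch.randn(2, 4, 8, 8)  # dummy "target"

# Stand-in transform (assumption): orthonormal FFT along the last dimension.
pred_f = torch.fft.fft(pred, dim=-1, norm="ortho")
target_f = torch.fft.fft(target, dim=-1, norm="ortho")

loss_freq = mse_complex(pred_f, target_f)                  # helper from step 1
loss_abs = torch.mean(torch.abs(pred_f - target_f) ** 2)   # same quantity via |.|
loss_pixel = torch.mean((pred - target) ** 2)              # plain L2 before the transform

print(loss_freq.item(), loss_abs.item(), loss_pixel.item())
```

If the three numbers agree, the transform is energy-preserving and the loss only starts to differ from L2 once the order, a nonlinearity, or a weighting is changed.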

2. Add "fft" to the choices of the --loss_type argument:

parser.add_argument(
    "--loss_type",
    type=str,
    default="l2",
    choices=["l1", "l2", "huber", "smooth_l1", "fft"],
    help="The type of loss function to use (L1, L2, Huber, smooth L1, or FFT); default is L2 / 使用する損失関数の種類(L1、L2、Huber、smooth L1、またはFFT)、デフォルトはL2",
)
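
A quick way to check the argument change in isolation is to build a tiny parser with only this option. This is a sketch; `build_parser` and `LOSS_CHOICES` are hypothetical names and not part of sd_scripts.

```python
import argparse

LOSS_CHOICES = ["l1", "l2", "huber", "smooth_l1", "fft"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-alone parser that only carries --loss_type."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--loss_type",
        type=str,
        default="l2",
        choices=LOSS_CHOICES,
        help="The type of loss function to use; default is L2",
    )
    return parser


parser = build_parser()
args = parser.parse_args(["--loss_type", "fft"])
default_args = parser.parse_args([])
print(args.loss_type, default_args.loss_type)
```

A value outside `choices` makes argparse exit with an error, so a typo in the training configuration is caught before training starts.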

3. Add an "fft" branch to the if/elif chain that selects the loss by loss_type:

elif loss_type == "fft":
    loss = mse_complex(dfrft(model_pred.float(), 1.0), dfrft(target.float(), 1.0))

If mixed-precision training causes dtype errors inside the transform, disable autocast around the loss:

elif loss_type == "fft":
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast
    with torch.autocast(device_type="cuda", enabled=False):
        loss = mse_complex(dfrft(model_pred.float(), 1.0), dfrft(target.float(), 1.0))
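
The branch above can be tried outside the trainer with a simplified dispatcher. This is a sketch under stated assumptions: `conditional_loss_sketch` is a hypothetical function, not the sd_scripts one, and the default transform is an orthonormal `torch.fft.fft` stand-in so that the example runs without torch-frft.

```python
import torch
import torch.nn.functional as F


def mse_complex(x, y):
    diff = x - y
    return torch.mean(diff.real ** 2 + diff.imag ** 2)


def conditional_loss_sketch(model_pred, target, loss_type, transform=None):
    """Hypothetical, simplified loss dispatcher (not the sd_scripts function)."""
    if loss_type == "l2":
        return F.mse_loss(model_pred, target)
    elif loss_type == "fft":
        if transform is None:
            # Stand-in (assumption): orthonormal FFT instead of dfrft(x, 1.0).
            transform = lambda x: torch.fft.fft(x, dim=-1, norm="ortho")
        # Disable autocast and upcast so the complex math runs in float32.
        with torch.autocast(device_type=model_pred.device.type, enabled=False):
            return mse_complex(transform(model_pred.float()), transform(target.float()))
    raise NotImplementedError(f"Unsupported loss type: {loss_type}")


torch.manual_seed(0)
pred = torch.randn(2, 4, 8, 8).to(torch.bfloat16)    # low-precision dummy inputs
target = torch.randn(2, 4, 8, 8).to(torch.bfloat16)
loss = conditional_loss_sketch(pred, target, "fft")
print(loss.dtype, loss.item())
```

With torch-frft installed, passing `transform=lambda x: dfrft(x, 1.0)` would exercise the transform used in this article. Note that this sketch returns a scalar; a training script that requests an unreduced loss and averages it per sample afterwards would need the reduction adapted.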

4. Set --loss_type="fft" in your training configuration.

Notes

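The transform order is a free parameter, so more elaborate losses can be built on top of the FrFT.
For example, the function below varies the order with the timestep, compresses the error magnitude, applies a log-Charbonnier penalty, and weights the result by noise level: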

def fft_loss(model_pred, target, timesteps, noise_scheduler,
             alpha_min=0.0, alpha_max=2.0, eps=1e-6, c=0.1, gamma=0.75):
    # === Per-sample FrFT order: alpha_max at t = 0, alpha_min at the last timestep ===
    timesteps = timesteps.float()
    t_norm = 1.0 - (timesteps / (noise_scheduler.config.num_train_timesteps - 1))
    alpha = alpha_min + t_norm * (alpha_max - alpha_min)

    # === Fractional transform per sample ===
    pred_frft_list = []
    tgt_frft_list = []
    for i in range(model_pred.shape[0]):
        a = alpha[i].item()
        pred_frft_list.append(dfrft(model_pred[i].float(), a))
        tgt_frft_list.append(dfrft(target[i].float(), a))

    pred_frft = torch.stack(pred_frft_list)
    tgt_frft  = torch.stack(tgt_frft_list)

    # === Difference in FrFT space ===
    diff = pred_frft - tgt_frft

    # === Power-law compression of the error magnitude (phase is preserved) ===
    diff_hyper = torch.sgn(diff) * (torch.abs(diff) + eps) ** gamma

    # === Log-Charbonnier penalty on the compressed error ===
    charbonnier = torch.log1p(torch.sqrt(diff_hyper.real**2 + diff_hyper.imag**2 + eps**2))

    # === SNR weighting from the scheduler ===
    alphas_cumprod = torch.index_select(
        noise_scheduler.alphas_cumprod, 0,
        timesteps.to(noise_scheduler.alphas_cumprod.device).long()
    )
    sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
    snr_weight = (1 - c) / (1 + sigmas)**2 + c
    # Move the weights to the loss tensor's device and broadcast over all non-batch dims
    snr_weight = snr_weight.to(charbonnier.device).view(-1, *([1] * (charbonnier.ndim - 1)))

    return (snr_weight * charbonnier).mean()
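
To see how the two timestep-dependent quantities in `fft_loss` behave, the sketch below evaluates the FrFT order and the noise-level weight for a few timesteps. The scaled-linear beta schedule and its constants are assumptions chosen to resemble common Stable Diffusion settings; in real training these values come from `noise_scheduler`. The helper names `frft_order` and `snr_weight` are hypothetical.

```python
import torch

num_train_timesteps = 1000
# Assumed scaled-linear beta schedule; real values come from the noise scheduler.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def frft_order(timesteps, alpha_min=0.0, alpha_max=2.0):
    # Linear map: t = 0 -> alpha_max, t = num_train_timesteps - 1 -> alpha_min
    t_norm = 1.0 - timesteps.float() / (num_train_timesteps - 1)
    return alpha_min + t_norm * (alpha_max - alpha_min)


def snr_weight(timesteps, c=0.1):
    # sigma = sqrt((1 - a_bar) / a_bar); weight falls from about 1 toward the floor c
    acp = alphas_cumprod[timesteps.long()]
    sigmas = ((1.0 - acp) / acp) ** 0.5
    return (1 - c) / (1 + sigmas) ** 2 + c


timesteps = torch.tensor([0, 500, 999])
orders = frft_order(timesteps)
weights = snr_weight(timesteps)
print(orders.tolist())
print(weights.tolist())
```

With the default arguments the order runs from 2.0 at the cleanest timestep down to 0.0 at the noisiest one. In the continuous FrFT, order 0 is the identity, order 1 is the ordinary Fourier transform, and order 2 is a coordinate reversal, so narrowing `alpha_min` and `alpha_max` changes which domain the loss emphasizes. The weight stays above the floor `c`, so heavily noised samples still contribute.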

Be creative.
