This article describes how to use the fractional Fourier transform (FrFT) as a loss function when training diffusion models.
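In other words (a compact statement of the basic loss used in the steps below, writing $\mathcal{F}^{a}$ for the FrFT of order $a$, $\hat{\epsilon}$ for the model prediction, and $\epsilon$ for the training target):

$$
\mathcal{L}_{\mathrm{FrFT}} = \mathbb{E}\big[\,\lvert \mathcal{F}^{a}\{\hat{\epsilon}\} - \mathcal{F}^{a}\{\epsilon\}\rvert^{2}\,\big], \qquad a = 1 \text{ in the simplest variant below.}
$$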
What you need:
Integrate the formula into your training process (the example below uses sd_scripts) and fix any errors that come up along the way.
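If you want to verify the dependency before editing anything, here is a minimal sketch (assuming the library is installed, e.g. via `pip install torch-frft`, and that `dfrft` defaults to transforming along the last dimension): it applies the discrete FrFT to a random tensor and checks that the result is complex-valued with the same shape.

```python
import torch
from torch_frft.dfrft_module import dfrft

x = torch.randn(4, 64)    # a small batch of real-valued 1-D signals
X_frac = dfrft(x, 0.5)    # fractional order between identity (a=0) and the DFT (a=1)
X_full = dfrft(x, 1.0)    # order 1.0 corresponds to the ordinary discrete Fourier transform

print(X_frac.dtype, X_frac.shape)   # complex dtype, same shape as the input
print(X_full.dtype, X_full.shape)
```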
Steps
1. Add the basic structure to library/train_util.py:

```python
from torch_frft.frft_module import frft
from torch_frft.dfrft_module import dfrft, dfrftmtx


def mse_complex(x, y):
    # MSE over the real and imaginary parts of a complex-valued difference
    diff = x - y
    return torch.mean(diff.real ** 2 + diff.imag ** 2)
```
2. Define fft in the loss_type argument:

```python
parser.add_argument(
    "--loss_type",
    type=str,
    default="l2",
    choices=["l1", "l2", "huber", "smooth_l1", "fft"],
    help="The type of loss function to use (L1, L2, Huber, smooth L1, or FFT); default is L2",
)
```
3. Add the FFT loss to the conditional loss section (a condensed sketch of how this branch fits into the surrounding dispatch follows these steps):

```python
elif loss_type == "fft":
    loss = mse_complex(dfrft(model_pred.float(), 1.0), dfrft(target.float(), 1.0))
```

In some cases (for example, when mixed-precision autocast interferes with the complex-valued transform), you may need to disable autocast around it:

```python
elif loss_type == "fft":
    with torch.cuda.amp.autocast(enabled=False):
        loss = mse_complex(dfrft(model_pred.float(), 1.0), dfrft(target.float(), 1.0))
```

4. Set --loss_type="fft" in your training configuration.
Notes
The FFT loss is quite flexible and allows for more elaborate loss constructions.
For example, the variant below varies the fractional order with the timestep, applies a log-Charbonnier penalty in FrFT space, and adds SNR weighting:
```python
def fft_loss(model_pred, target, timesteps, noise_scheduler,
             alpha_min=0.0, alpha_max=2.0, eps=1e-6, c=0.1, gamma=0.75):
    # === Per-sample normalization ===
    timesteps = timesteps.float()
    t_norm = 1.0 - (timesteps / (noise_scheduler.config.num_train_timesteps - 1))
    alpha = alpha_min + t_norm * (alpha_max - alpha_min)

    # === Fractional transform per sample ===
    pred_frft_list = []
    tgt_frft_list = []
    for i in range(model_pred.shape[0]):
        a = alpha[i].item()
        pred_frft_list.append(dfrft(model_pred[i].float(), a))
        tgt_frft_list.append(dfrft(target[i].float(), a))
    pred_frft = torch.stack(pred_frft_list)
    tgt_frft = torch.stack(tgt_frft_list)

    # === Difference in FrFT space ===
    diff = pred_frft - tgt_frft

    # === Hypercube projection of errors ===
    diff_hyper = torch.sgn(diff) * (torch.abs(diff) + eps) ** gamma

    # === Log-Charbonnier in the distorted space ===
    charbonnier = torch.log1p(torch.sqrt(diff_hyper.real**2 + diff_hyper.imag**2 + eps**2))

    # === SNR weighting from the scheduler ===
    alphas_cumprod = torch.index_select(
        noise_scheduler.alphas_cumprod, 0,
        timesteps.to(noise_scheduler.alphas_cumprod.device).long()
    )
    sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
    snr_weight = (1 - c) / (1 + sigmas) ** 2 + c
    snr_weight = snr_weight.to(timesteps.device).view(-1, *([1] * (charbonnier.ndim - 1)))

    return (snr_weight * charbonnier).mean()
```

Be creative.
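To exercise fft_loss outside of a full training run, here is a hypothetical smoke test (assuming fft_loss and the dfrft import are defined as above, and using diffusers' DDPMScheduler as a stand-in noise scheduler; the tensor shapes and scheduler settings are placeholders):

```python
import torch
from diffusers import DDPMScheduler

# Random stand-ins for the model prediction and target (e.g. latent-shaped noise)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
model_pred = torch.randn(2, 4, 64, 64)
target = torch.randn(2, 4, 64, 64)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (2,))

loss = fft_loss(model_pred, target, timesteps, noise_scheduler)
print(loss)  # scalar tensor; should be finite and non-negative
```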

![FFT loss implementation [SDXL] [sd_scripts]](https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/5ec85c83-dc49-495c-9bb6-4e30cdbc0b01/width=1320/874798505337247.jpeg)