
What LoRA alpha actually does (in theory)

Sep 07, 2023

TL;DR: scaling alpha by k is equivalent to simultaneously scaling the learning rate by k and the initialization weights by sqrt(k).

The alpha hyperparameter in LoRAs is not well understood. Some say it "dampens learning" and others that it acts like the learning rate, but neither claim has been backed by clear evidence. This is my take on alpha: with the rank fixed, multiplying alpha by k while dividing the learning rate by k and the LoRA initialization weights by sqrt(k) is a no-op. In other words, scaling alpha by k is equivalent to simultaneously scaling the learning rate by k and the initialization weights by sqrt(k). This means that instead of tuning alpha, we should be tuning the learning rate and initialization.
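The no-op claim can be checked numerically without loralib. Below is a minimal NumPy sketch, assuming a toy single-layer setup with a squared-error loss and plain SGD; the helper `train`, the sizes, and the seed are all hypothetical. Following loralib's convention, the factors are stored so the update is `scaling * B @ A` with `scaling = alpha / rank`, and both factors start from nonzero random values so the init scaling affects both. The runs are compared through the effective update the layer adds, not the raw factors, which differ by sqrt(k) by design.

```python
import numpy as np

def train(alpha, lr, init_scale, steps=5, rank=2, seed=0):
    """Plain SGD on a toy LoRA layer; returns the effective update scaling * B @ A."""
    rng = np.random.default_rng(seed)   # same seed: same data, base weight and raw init
    x, y = rng.standard_normal((16, 4)), rng.standard_normal((16, 4))
    W = rng.standard_normal((4, 4))     # frozen base weight
    A = rng.standard_normal((rank, 4)) * init_scale
    B = rng.standard_normal((4, rank)) * init_scale
    scaling = alpha / rank
    for _ in range(steps):
        residual = x @ (W + scaling * B @ A).T - y
        G = residual.T @ x              # gradient of 0.5 * ||residual||^2 w.r.t. the update
        grad_A = scaling * B.T @ G      # chain rule through scaling * B @ A
        grad_B = scaling * G @ A.T
        A, B = A - lr * grad_A, B - lr * grad_B
    return scaling * B @ A

k = 4.0
base = train(alpha=1.0, lr=1e-4, init_scale=1.0)
# alpha * k, lr / k, init / sqrt(k): should reproduce the base run.
noop = train(alpha=k, lr=1e-4 / k, init_scale=1.0 / np.sqrt(k))
# alpha * k alone ...
scaled_alpha = train(alpha=k, lr=1e-4, init_scale=1.0)
# ... should match the original alpha with lr * k and init * sqrt(k).
scaled_lr_init = train(alpha=1.0, lr=1e-4 * k, init_scale=np.sqrt(k))
```

With k = 4, `base` and `noop` agree up to floating-point error, as do `scaled_alpha` and `scaled_lr_init`, while changing alpha alone (`scaled_alpha` versus `base`) gives a different result.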

LoRAs are residual layers that add AB * scaling (applied to the input) to the original output, where A and B are low-rank matrices and scaling is the constant alpha / rank. The weight initialization does not depend on alpha. Below is the rough derivation:
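A minimal NumPy sketch of that forward pass, assuming loralib's shape convention (A is rank × in_features and B is out_features × rank, so the product is written B @ A); the helper name `lora_forward` is hypothetical. It shows that scaling is a plain constant: with B at zero the layer reproduces the base output, and doubling alpha doubles the residual.

```python
import numpy as np

def lora_forward(x, W, A, B, alpha):
    """Frozen base layer plus the scaled low-rank residual (hypothetical helper)."""
    rank = A.shape[0]                   # A has shape (rank, in_features)
    scaling = alpha / rank              # a constant; the init does not depend on it
    return x @ W.T + scaling * (x @ A.T @ B.T)

rng = np.random.default_rng(0)
in_f, out_f, rank = 4, 3, 2
x = rng.standard_normal((5, in_f))
W = rng.standard_normal((out_f, in_f))  # frozen base weight
A = rng.standard_normal((rank, in_f))
B = np.zeros((out_f, rank))             # B starts at zero, as in the LoRA paper and loralib

y_init = lora_forward(x, W, A, B, alpha=16)   # equals the base output at init

B_trained = rng.standard_normal((out_f, rank))
res_1 = lora_forward(x, W, A, B_trained, alpha=1) - x @ W.T
res_2 = lora_forward(x, W, A, B_trained, alpha=2) - x @ W.T  # residual doubles with alpha
```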

Gradient descent on A and B for a loss function f is the following, where D(y, x) denotes the derivative of y with respect to x:
A := A - lr * D(f(AB), A)
B := B - lr * D(f(AB), B)
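To make D(f(AB), A) concrete, here is a minimal NumPy sketch assuming a squared-error loss on a single layer and loralib's B @ A ordering (update = scaling * B @ A); the helpers `loss_and_grads` and `sgd_step` are hypothetical. The gradients come from the chain rule through the product, and a central finite difference confirms one entry.

```python
import numpy as np

def loss_and_grads(x, y, W, A, B, scaling):
    """Squared-error loss of a toy LoRA layer and its gradients w.r.t. A and B."""
    residual = x @ (W + scaling * B @ A).T - y
    loss = 0.5 * np.sum(residual ** 2)
    G = residual.T @ x                  # dL / d(update), shape (out, in)
    grad_A = scaling * B.T @ G          # chain rule through update = scaling * B @ A
    grad_B = scaling * G @ A.T
    return loss, grad_A, grad_B

def sgd_step(x, y, W, A, B, scaling, lr):
    """One step of A := A - lr * D(f, A), B := B - lr * D(f, B)."""
    _, grad_A, grad_B = loss_and_grads(x, y, W, A, B, scaling)
    return A - lr * grad_A, B - lr * grad_B

rng = np.random.default_rng(1)
x, y = rng.standard_normal((6, 4)), rng.standard_normal((6, 3))
W = rng.standard_normal((3, 4))
A, B = rng.standard_normal((2, 4)), rng.standard_normal((3, 2))
scaling = 8 / 2                         # alpha = 8, rank = 2

# Central finite difference on one entry of A.
eps = 1e-6
A_plus = A.copy()
A_plus[0, 1] += eps
A_minus = A.copy()
A_minus[0, 1] -= eps
fd = (loss_and_grads(x, y, W, A_plus, B, scaling)[0]
      - loss_and_grads(x, y, W, A_minus, B, scaling)[0]) / (2 * eps)
analytic = loss_and_grads(x, y, W, A, B, scaling)[1][0, 1]
```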

When we scale alpha by k, the layer's update becomes k times larger for the same factors, so we perform the change of variables AB = k UV and optimize U and V instead:
U := U - c * lr * D(f(k UV), U)
V := V - c * lr * D(f(k UV), V)
where c is a learning-rate scaling factor.

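We want the two methods to produce the same loss value at every gradient descent step. That holds if k UV stays equal to AB, i.e., if each entry of U and V receives the same relative update as the corresponding entry of A and B (the shared lr factor cancels on both sides):

D(f(AB), A_ij) / A_ij = c D(f(k UV), U_ij) / U_ij
D(f(AB), B_ij) / B_ij = c D(f(k UV), V_ij) / V_ij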

This has a solution:
U = A / sqrt(k)
V = B / sqrt(k)
c = 1 / k
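Substituting the solution back in makes the check explicit. Here ∇f denotes the gradient of f with respect to the product (a symbol introduced only for this check), so the chain rule gives D(f(AB), A) = ∇f(AB) Bᵀ and likewise for U:

```latex
\begin{aligned}
kUV &= k \cdot \frac{A}{\sqrt{k}} \cdot \frac{B}{\sqrt{k}} = AB, \\
D\big(f(kUV), U\big) &= k\,\nabla f(kUV)\,V^{\top} = \sqrt{k}\,\nabla f(AB)\,B^{\top} = \sqrt{k}\,D\big(f(AB), A\big), \\
U - c\cdot\mathrm{lr}\cdot D\big(f(kUV), U\big) &= \frac{A}{\sqrt{k}} - \frac{\mathrm{lr}}{k}\,\sqrt{k}\,D\big(f(AB), A\big) = \frac{A - \mathrm{lr}\cdot D\big(f(AB), A\big)}{\sqrt{k}}, \\
c\,\frac{D\big(f(kUV), U_{ij}\big)}{U_{ij}} &= \frac{1}{k}\cdot\frac{\sqrt{k}\,D\big(f(AB), A_{ij}\big)}{A_{ij}/\sqrt{k}} = \frac{D\big(f(AB), A_{ij}\big)}{A_{ij}}.
\end{aligned}
```

The same steps, with D(f(k UV), V) = k Uᵀ ∇f(k UV) = sqrt(k) D(f(AB), B), cover V. By induction, U = A / sqrt(k) and V = B / sqrt(k) after every step, so k UV = AB and both runs report the same loss.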

What does this tell us about how to adjust alpha and the learning rate when scaling the rank? Unfortunately, not much, since LoRAs with different ranks have different learning dynamics and cannot be compared directly. Hopefully, this will be explored empirically in a follow-up post.


To sanity-check the result, I used this Python script:

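import random
from pprint import pprint

import torch
import loralib as lora

N = 4   # input and output features of the layer
M = 16  # number of samples
s = random.randint(1, 99999999)  # one seed shared by every run

for alpha in (1/4, 1, 4, 16, 64, 256):
  # Reseed so every alpha sees the same data, base weight and LoRA init.
  torch.manual_seed(s)
  X, Y = torch.randn((M, N)), torch.randn((M, N))
  layer = lora.Linear(N, N, 2, lora_alpha=alpha, bias=False)  # rank 2
  with torch.no_grad():
    # Divide the init by sqrt(alpha); lora_B starts at zero, so only lora_A changes.
    layer.lora_A /= alpha ** .5
  print(f"State before opt ({alpha=}):")
  pprint({k: v for k, v in layer.named_parameters() if v.requires_grad})
  # Divide the learning rate by alpha; only the LoRA factors are trainable.
  trainable = [p for p in layer.parameters() if p.requires_grad]
  opt = torch.optim.SGD(trainable, lr=1e-1 / alpha)
  for _ in range(5):
    opt.zero_grad()
    loss = torch.nn.functional.l1_loss(layer(X), Y)
    loss.backward()
    opt.step()
  print(f"State after opt ({alpha=}):")
  pprint({k: v for k, v in layer.named_parameters() if v.requires_grad})
  # The raw factors differ by sqrt(alpha), but the update the layer actually adds,
  # scaling * lora_B @ lora_A, should agree across alphas up to float error.
  print("Effective update:", (layer.lora_B @ layer.lora_A).detach() * layer.scaling)
  print('-' * 20)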