Targeted Unsupervised Machine Unlearning in Latent Diffusion (UNet or Transformer Denoiser)
0) Problem statement and assumptions
We want Targeted Unsupervised Machine Unlearning for a latent text-to-image diffusion model.
Assumptions:
- A folder called Forget/ contains images of a concept to erase (for example, cats).
- No labels, captions, or boxes are available.
- The model is a latent diffusion model, meaning it uses a VAE encoder/decoder to move between image space and latent space.
- We freeze the VAE and all prompt encoders (text encoders). We unlearn by modifying the denoiser only.
We cover both denoiser types:
- UNet denoiser (SDXL-style)
- Transformer denoiser (diffusion transformer variants)
Goal:
- Remove the target concept while preserving everything else (styles, other objects, composition priors).
1) The Teacher-Student Scrubber (summary)
We unlearn by training a Student denoiser to behave like a Teacher denoiser that is “blind” to the concept region.
- Teacher: a frozen copy of the original denoiser, used for (1) stable feature extraction and (2) stable regression targets. In other words, it is the original, pretrained denoiser with frozen weights and no learning: it represents what the model used to know and gives you a stable reference.
- Student: a trainable copy of the denoiser, updated to erase the concept. This is the model you are editing (unlearning).
Key idea:
- Detect “concept-like” regions in the latent using a frozen teacher feature space.
- Create a manifold-consistent “blinded latent” by corrupting the clean latent in the detected region, then re-noising it.
- Distill the Student so that inside the mask it matches the Teacher’s prediction on the blinded latent (inpainting), and outside the mask it matches the Teacher on the original latent (preservation).
- Add a small retain set anchor to prevent soft global degradation.
Phase 1: Learn the concept fingerprint (prototypes) from Forget/
Forget images (no labels)
|
v
VAE encoder (frozen)
|
v
clean latents z0
|
v
add noise at semantic t -> z_t
|
v
Teacher denoiser (frozen)
(tap mid-block features F)
|
v
flatten features into vectors
|
v
K-means clustering
|
v
prototypes P = {p1..pK}
(concept fingerprint)
Phase 2: Unlearning training loop (Teacher guides Student)
(per training step, on a batch of forget images)
forget images
|
v
VAE encoder (frozen) ---------------------> clean latents z0
|
v
sample timestep t and noise eps
|
v
make noisy latent: z_t = add_noise(z0, eps, t)
|
+-------------------------------+
| |
| |
v v
Teacher (frozen) Student (trainable)
on z_t on z_t
| |
| v
| prediction S
|
+--> extract features F
|
+--> compare F to prototypes P
| (cosine similarity)
|
+--> build mask M (0..1)
"where the concept is"
Now create a blinded version for the Teacher target:
z0_blind = (1 - M) * z0 + M * random_texture
z_t_blind = add_noise(z0_blind, eps, t)
Teacher targets:
T_orig = Teacher(z_t) (preserve outside concept)
T_blind = Teacher(z_t_blind) (inpaint inside concept)
Losses:
Inside mask: Student(z_t) should match T_blind
Outside mask: Student(z_t) should match T_orig
Optional retain: Student should match Teacher on retain images
2) Notation
Datasets and signals:
- Forget dataset: \( \mathcal{D}_f \)
- Retain dataset: \( \mathcal{D}_r \)
Image: \( x \in \mathbb{R}^{3 \times H \times W} \)
- \( H \): image height (pixels)
- \( W \): image width (pixels)
Frozen VAE:
- Frozen VAE encoder: \( E_\phi \)
- Frozen VAE decoder: \( D_\phi \)
- \( \phi \): frozen VAE parameters
- VAE scale factor: \( s \) (read from the VAE config)
Clean latent:
\( z_0 = s \cdot E_\phi(x) \)
where:
- \( z_0 \): clean latent (latent representation of the image before forward diffusion)
- \( s \): scalar latent scaling factor from the VAE config
- \( E_\phi(x) \): VAE-encoded latent of image \( x \)
- \( x \): input image
- \( \phi \): VAE parameters (frozen)
Forward noising:
\( z_t = \alpha_t \, z_0 + \sigma_t \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \)
where:
- \( z_t \): noisy latent at timestep \( t \)
- \( t \): diffusion timestep index (integer)
- \( \alpha_t \): scheduler coefficient multiplying the clean signal at timestep \( t \)
- \( \sigma_t \): scheduler coefficient multiplying the noise at timestep \( t \)
- \( z_0 \): clean latent (defined above)
- \( \epsilon \): Gaussian noise sample
- \( \mathcal{N}(0, I) \): multivariate normal with mean 0 and covariance \( I \)
- \( I \): identity matrix (dimension matches the latent dimensionality)
Denoisers:
- Student denoiser: \( \epsilon_\theta(z_t, t, c) \) with trainable parameters \( \theta \)
- Teacher denoiser (frozen): \( \epsilon_{\theta^*}(z_t, t, c) \) with frozen parameters \( \theta^* \)
Conditioning:
- Conditioning bundle: \( c \) (depends on model, SDXL uses micro-conditioning, transformer families vary)
Teacher feature extractor:
- Feature extractor from Teacher internals: \( F = \Phi_{\theta^*}(z_t, t, c) \)
UNet vs Transformer teacher feature shapes:
- UNet case: \( F \in \mathbb{R}^{C \times H_f \times W_f} \)
- Transformer case: \( F \in \mathbb{R}^{N \times D} \)
Where:
- \( C \): number of feature channels (UNet feature map)
- \( H_f, W_f \): feature map height and width
- \( N \): number of tokens (image tokens after selecting the right slice)
- \( D \): token feature dimension
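For concreteness, here is a minimal sketch of how the clean-latent and forward-noising equations map onto diffusers calls; it assumes pipe is a loaded pipeline and images is a batch already in the range the VAE expects.
import torch
# z_0 = s * E_phi(x): encode to latent space and apply the VAE scaling factor s
with torch.no_grad():
    z0 = pipe.vae.encode(images).latent_dist.sample() * pipe.vae.config.scaling_factor
# z_t = alpha_t * z0 + sigma_t * eps: the scheduler implements the forward noising equation
t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (z0.shape[0],), device=z0.device)
eps = torch.randn_like(z0)
z_t = pipe.scheduler.add_noise(z0, eps, t)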
3) Phase 1: Unsupervised fingerprinting (prototype discovery)
We learn a “concept fingerprint” from Forget/ via clustering, using Teacher features so the representation does not drift during training.
3.1 Feature extraction (semantic timestep band)
For each image \( x_i \in \mathcal{D}_f \):
- Encode \( x_i \) into a clean latent \( z_0 \)
- Sample a timestep \( t \) from a semantic band (for example 300 to 700 out of 1000 steps)
- Build \( z_t \) using the forward noising equation
- Extract teacher features \( F_i = \Phi_{\theta^*}(z_t, t, c) \)
Convert features into vectors:
- UNet: each spatial location becomes a vector
- Transformer: each image token is already a vector
Subsample vectors per image to avoid memory blow-up.
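As a concrete illustration of the UNet case, here is a minimal sketch of the flatten-and-subsample step; flatten_and_subsample is a hypothetical helper, and feats stands for a mid-block feature map captured by a forward hook.
import torch
def flatten_and_subsample(feats: torch.Tensor, max_vecs_per_image: int = 2000) -> torch.Tensor:
    # feats: teacher mid-block features [B, C, Hf, Wf]
    B, C, Hf, Wf = feats.shape
    vecs = feats.permute(0, 2, 3, 1).reshape(B * Hf * Wf, C)  # one C-dimensional vector per spatial location
    max_total = max_vecs_per_image * B
    if vecs.shape[0] > max_total:  # random subsample to cap memory
        idx = torch.randperm(vecs.shape[0], device=vecs.device)[:max_total]
        vecs = vecs[idx]
    return vecs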
3.2 Clustering (multi-prototype fingerprint)
Let \( V_i \) be the set of extracted vectors for image \( x_i \). We cluster all vectors to obtain prototypes:
\( \{p_1, \ldots, p_K\} = \mathrm{KMeans}_K\!\left( \bigcup_{x_i \in \mathcal{D}_f} V_i \right) \)
where:
- \( p_k \): prototype (cluster center) number \( k \)
- \( K \): number of prototypes (clusters)
- \( \mathrm{KMeans}_K \): K-Means clustering that outputs \( K \) centers
- \( V_i \): set of vectors extracted from image \( x_i \)
- \( \mathcal{D}_f \): forget dataset (defined earlier)
Normalize each prototype:
\( p_k \leftarrow p_k / \lVert p_k \rVert_2 \)
where:
- \( p_k \): prototype vector
- \( \leftarrow \): assignment (replace \( p_k \) with its normalized version)
Interpretation:
- Each \( p_k \) captures a sub-feature of the concept (shape cues, textures, parts).
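A minimal sketch of the clustering step, assuming all_vecs is a NumPy array [M, D] that stacks the subsampled teacher feature vectors from every forget image; it mirrors the fit_prototypes method shown later.
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import MiniBatchKMeans
def cluster_prototypes(all_vecs: np.ndarray, num_clusters: int = 5, device: str = "cuda") -> torch.Tensor:
    kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=8192, random_state=42, n_init="auto")
    kmeans.fit(all_vecs)  # cluster centers become the prototypes
    protos = torch.tensor(kmeans.cluster_centers_, device=device, dtype=torch.float32)
    return F.normalize(protos, dim=1)  # unit-norm, so cosine similarity reduces to a dot product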
4) Phase 2: Detection + \( z_0 \)-blinding + teacher distillation
4.1 Stable detection (frozen eye)
Given a noisy latent \( z_t \), compute teacher features \( F = \Phi_{\theta^*}(z_t, t, c) \). Then compute the max cosine similarity to the prototypes per location.
First normalize each local feature vector:
- UNet: \( \hat{F}_{ij} = F_{ij} / \lVert F_{ij} \rVert_2 \) for each spatial location \( (i, j) \)
- Transformer: \( \hat{F}_n = F_n / \lVert F_n \rVert_2 \) for each image token \( n \)
Then compute similarity scores.
UNet (location \( (i, j) \) in the teacher feature map):
\( S_{ij} = \max_k \cos(\hat{F}_{ij}, p_k) \)
where:
- \( S_{ij} \): similarity score at feature-map location \( (i, j) \)
- \( \max_k \): maximum over prototype index \( k \)
- \( \cos(u, v) \): cosine similarity between vectors \( u \) and \( v \)
- \( \hat{F}_{ij} \): L2-normalized teacher feature vector at location \( (i, j) \)
Transformer (location is token index \( n \)):
\( S_n = \max_k \cos(\hat{F}_n, p_k) \)
where:
- \( S_n \): similarity score for token \( n \)
- \( \max_k \): maximum over prototype index \( k \)
- \( \cos(u, v) \): cosine similarity
- \( \hat{F}_n \): L2-normalized teacher feature vector for token \( n \)
- \( p_k \): prototype \( k \)
Soft mask:
\( M = \sigma\big(\alpha (S - \tau)\big) \)
where:
- \( M \): soft mask values in \( [0, 1] \), per location (grid for UNet, tokens for transformer)
- \( \sigma \): sigmoid function, \( \sigma(u) = 1 / (1 + e^{-u}) \)
- \( \alpha \): slope controlling the sharpness of the mask
- \( S \): similarity field (either \( S_{ij} \) or \( S_n \))
- \( \tau \): similarity threshold that defines “concept-like”
- \( e \): Euler’s number (base of the natural logarithm)
Typical values: \( \tau \approx 0.6 \) and \( \alpha \approx 15 \) (the code below uses tau=0.60 and alpha=15.0).
Reshape and upsample:
- UNet: reshape \( M \) into an \( H_f \times W_f \) grid
- Transformer: reshape \( M \) into an \( H_t \times W_t \) token grid, where \( N = H_t \times W_t \)
Finally upsample to latent resolution and broadcast across latent channels.
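A minimal sketch of the detection step, assuming feat_vec holds flattened teacher features [B, L, D] (L = Hf*Wf for the UNet, or N tokens for the transformer) and prototypes is the unit-norm [K, D] matrix from Phase 1; it mirrors the compute_mask method below.
import torch
import torch.nn.functional as F
def compute_soft_mask(feat_vec, prototypes, grid_hw, latent_hw, tau=0.60, alpha=15.0):
    B, L, D = feat_vec.shape
    feat_norm = F.normalize(feat_vec, dim=-1)                   # L2-normalized local features
    sim = feat_norm @ prototypes.t()                            # [B, L, K] cosine similarities
    max_sim = sim.max(dim=-1).values                            # S: best-matching prototype per location
    mask = torch.sigmoid(alpha * (max_sim - tau))               # M = sigmoid(alpha * (S - tau))
    mask = mask.view(B, 1, *grid_hw)                            # back to the 2D location grid
    return F.interpolate(mask, size=latent_hw, mode="nearest")  # upsample to latent resolution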
4.2 Manifold-consistent blinding (blind \( z_0 \), then re-noise)
Blind the clean latent:
\( z_0^{\text{blind}} = (1 - M_{\text{lat}}) \odot z_0 + M_{\text{lat}} \odot \eta, \quad \eta \sim \mathcal{N}(0, I) \)
where:
- \( z_0^{\text{blind}} \): blinded clean latent
- \( M_{\text{lat}} \): upsampled mask at latent resolution
- \( 1 - M_{\text{lat}} \): complement of the mask (keeps background)
- \( \odot \): elementwise (Hadamard) product
- \( z_0 \): clean latent
- \( \eta \): replacement Gaussian noise for the masked region
- \( \mathcal{N}(0, I) \): standard Gaussian distribution
- \( I \): identity covariance
Re-noise with the same timestep \( t \) and the same forward noise \( \epsilon \):
\( z_t^{\text{blind}} = \alpha_t \, z_0^{\text{blind}} + \sigma_t \, \epsilon \)
where:
- \( z_t^{\text{blind}} \): noisy latent built from the blinded clean latent
- \( \alpha_t \): scheduler signal coefficient at timestep \( t \)
- \( z_0^{\text{blind}} \): blinded clean latent
- \( \sigma_t \): scheduler noise coefficient at timestep \( t \)
- \( \epsilon \): the same Gaussian noise used when building \( z_t \)
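A minimal sketch of the blinding step, assuming mask is the upsampled latent-resolution mask [B, 1, H, W] and scheduler is a diffusers noise scheduler; eps and t must be the same noise and timestep already used to build z_t.
import torch
def blind_and_renoise(z0, mask, eps, t, scheduler):
    eta = torch.randn_like(z0)                         # replacement texture for the concept region
    z0_blind = (1.0 - mask) * z0 + mask * eta          # keep the background, scrub the masked region
    z_t_blind = scheduler.add_noise(z0_blind, eps, t)  # re-noise with the SAME eps and t as z_t
    return z_t_blind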
5) Three-part loss
Teacher predictions:
\( T_{\text{orig}} = \epsilon_{\theta^*}(z_t, t, c) \)
where:
- \( T_{\text{orig}} \): teacher noise prediction on the original \( z_t \)
- \( \epsilon_{\theta^*} \): teacher denoiser with frozen parameters \( \theta^* \)
- \( z_t \): noisy latent at timestep \( t \)
- \( t \): timestep
- \( c \): conditioning bundle
\( T_{\text{blind}} = \epsilon_{\theta^*}(z_t^{\text{blind}}, t, c) \)
where:
- \( T_{\text{blind}} \): teacher noise prediction on the blinded noisy latent
- \( \epsilon_{\theta^*} \): teacher denoiser
- \( z_t^{\text{blind}} \): blinded noisy latent
- \( t \): timestep
- \( c \): conditioning bundle
Student prediction:
\( S = \epsilon_\theta(z_t, t, c) \)
where:
- \( S \): student noise prediction
- \( \epsilon_\theta \): student denoiser with trainable parameters \( \theta \)
- \( z_t \): noisy latent
- \( t \): timestep
- \( c \): conditioning bundle
Losses:
- Forget loss (inside mask):
\( \mathcal{L}_{\text{forget}} = \lVert M \odot (S - T_{\text{blind}}) \rVert_2^2 \)
where:
- \( \mathcal{L}_{\text{forget}} \): forgetting loss
- \( \lVert \cdot \rVert_2^2 \): squared L2 (squared Frobenius norm over all tensor elements)
- \( S \): student prediction
- \( T_{\text{blind}} \): teacher prediction on the blinded latent
- \( \odot \): elementwise product
- Background preservation (outside mask):
\( \mathcal{L}_{\text{bg}} = \lVert (1 - M) \odot (S - T_{\text{orig}}) \rVert_2^2 \)
where:
- \( \mathcal{L}_{\text{bg}} \): background preservation loss
- \( S \): student prediction
- \( T_{\text{orig}} \): teacher prediction on the original latent
- \( 1 - M \): mask complement
- \( \odot \): elementwise product
- \( \lVert \cdot \rVert_2^2 \): squared L2 norm
- Retain anchor (on retain data):
\( \mathcal{L}_{\text{retain}} = \lVert \epsilon_\theta(z_t^r, t^r, c^r) - \epsilon_{\theta^*}(z_t^r, t^r, c^r) \rVert_2^2 \)
where:
- \( \mathcal{L}_{\text{retain}} \): retain anchoring loss
- \( \epsilon_\theta \): student denoiser
- \( \epsilon_{\theta^*} \): teacher denoiser
- \( z_t^r \): noisy latent built from a retain image (superscript \( r \) means retain)
- \( t^r \): timestep sampled for the retain batch
- \( c^r \): conditioning for the retain batch
- \( \lVert \cdot \rVert_2^2 \): squared L2 norm
Total loss:
\( \mathcal{L} = \lambda_f \, \mathcal{L}_{\text{forget}} + \lambda_b \, \mathcal{L}_{\text{bg}} + \lambda_r \, \mathcal{L}_{\text{retain}} \)
where:
- \( \mathcal{L} \): total training loss
- \( \lambda_f \): weight for the forget loss
- \( \lambda_b \): weight for the background loss
- \( \lambda_r \): weight for the retain loss
- \( \mathcal{L}_{\text{forget}}, \mathcal{L}_{\text{bg}}, \mathcal{L}_{\text{retain}} \): component losses defined above
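The implementation below realizes these terms with mean-squared error (the squared L2 norm normalized by element count). A minimal sketch, assuming S, T_blind, T_orig, and M are as defined above and Sr, Tr are optional retain-batch predictions:
import torch.nn.functional as F
def scrubber_loss(S, T_blind, T_orig, M, Sr=None, Tr=None, lambda_f=1.0, lambda_b=1.0, lambda_r=0.5):
    loss_forget = F.mse_loss(S * M, T_blind * M)              # inside the mask: match the inpainting target
    loss_bg = F.mse_loss(S * (1.0 - M), T_orig * (1.0 - M))   # outside the mask: match the original teacher
    total = lambda_f * loss_forget + lambda_b * loss_bg
    if Sr is not None and Tr is not None:
        total = total + lambda_r * F.mse_loss(Sr, Tr)         # retain anchor on unrelated data
    return total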
6) Why this works (what matters, briefly)
- Freezing the Teacher makes the feature space stationary, so the detector does not drift during training.
- Using Teacher predictions on a blinded latent produces low-variance, context-aware targets (inpainting) instead of unstable random-noise targets.
- The split loss (mask vs background) prevents the Student from changing unrelated regions.
- Retain anchoring reduces soft global degradation due to parameter entanglement.
What is different between UNet and Transformer denoisers, and why
Difference 1: what “locations” mean
- UNet gives a spatial feature map \( F \in \mathbb{R}^{C \times H_f \times W_f} \), so your mask is naturally 2D.
- Transformer gives token features \( F \in \mathbb{R}^{N \times D} \), so your mask is per-token, then reshaped to a 2D token grid.
Why: UNets operate on feature maps, transformers operate on sequences of patch tokens.
Difference 2: where to hook
- UNet: hook the mid_block output.
- Transformer: hook a middle transformer block output (hidden states for image tokens).
Why: mid-depth representations tend to be the most semantic and stable for clustering and detection.
Difference 3: conditioning plumbing
- SDXL UNet needs SDXL micro-conditioning (time ids, pooled text embeds).
- Transformer pipelines vary. Some use cross-attention, some use joint attention, some use different added_cond_kwargs.
Why: transformer denoisers are not standardized across all diffusion families. The algorithm is unchanged, but the call signature and conditioning keys are model-specific.
Code: one scrubber, two adapters (UNet vs Transformer)
This code makes the distinction explicit. You plug in:
- a UNetMidBlockAdapter for the SDXL UNet, or
- a TransformerBlockAdapter for diffusion transformers
It is designed as a template because transformer pipelines differ. The adapter is where you adjust forward signatures and token-grid shape.
Important implementation note (hook correctness): the feature hook must be registered on the Teacher instance that you actually call inside features(). The version below does that by building the adapter from the teacher model via an adapter_factory.
import copy
import numpy as np
from sklearn.cluster import MiniBatchKMeans
import torch
import torch.nn.functional as F
# ============================================================
# Adapters (this is where UNet vs Transformer differ)
# ============================================================
class BaseDenoiserAdapter:
"""
Defines a common interface for UNet or Transformer denoisers.
Must provide:
- forward(model, z_t, t, cond) -> prediction tensor (same shape as z_t)
- features(model, z_t, t, cond) -> (vectors, grid_hw)
vectors: [B, L, D] where L = H*W (UNet) or N tokens (Transformer)
grid_hw: (H_loc, W_loc) so we can reshape mask to 2D before upsampling
"""
def forward(self, model, z_t, t, cond):
raise NotImplementedError
def features(self, model, z_t, t, cond):
raise NotImplementedError
class UNetMidBlockAdapter(BaseDenoiserAdapter):
"""
Adapter for SDXL-style UNet denoisers.
Features come from mid_block output: [B, C, Hf, Wf]
"""
def __init__(self, unet):
self._feat = None
def hook_fn(module, inputs, output):
if isinstance(output, (tuple, list)):
output = output[0]
self._feat = output
unet.mid_block.register_forward_hook(hook_fn)
def forward(self, model, z_t, t, cond):
# Diffusers UNet returns object with .sample
return model(z_t, t, **cond).sample
def features(self, model, z_t, t, cond):
_ = model(z_t, t, **cond) # fills hook
feats = self._feat # [B, C, Hf, Wf]
if feats is None:
raise RuntimeError("UNet mid_block hook failed.")
B, C, Hf, Wf = feats.shape
vec = feats.permute(0, 2, 3, 1).reshape(B, Hf * Wf, C) # [B, L, D]
return vec, (Hf, Wf)
class TransformerBlockAdapter(BaseDenoiserAdapter):
"""
Adapter for diffusion-transformer denoisers.
You must:
- pick a block to hook (often middle depth)
- define how many image tokens exist, and their grid shape (Ht, Wt)
- ensure the hook captures image-token hidden states, not text tokens
This is a template because different models structure blocks differently.
"""
def __init__(self, transformer, block, token_grid_hw, image_token_slice=None):
"""
transformer: the denoiser module
block: a module inside transformer to hook (for example transformer.blocks[mid])
token_grid_hw: (Ht, Wt) so N == Ht*Wt for image tokens
image_token_slice: optional slice to select only image tokens if output includes text tokens
"""
self.token_grid_hw = token_grid_hw
self.image_token_slice = image_token_slice
self._feat = None
def hook_fn(module, inputs, output):
# output often is hidden states [B, N, D], but may be tuple
if isinstance(output, (tuple, list)):
output = output[0]
self._feat = output
block.register_forward_hook(hook_fn)
def forward(self, model, z_t, t, cond):
# You may need to adjust this to your model signature.
out = model(z_t, t, **cond)
# Many models expose .sample like UNet, some return tensor directly.
return out.sample if hasattr(out, "sample") else out
def features(self, model, z_t, t, cond):
_ = model(z_t, t, **cond) # fills hook
feats = self._feat # expected [B, N, D]
if feats is None:
raise RuntimeError("Transformer hook failed, no features captured.")
if self.image_token_slice is not None:
feats = feats[:, self.image_token_slice, :]
B, N, D = feats.shape
Ht, Wt = self.token_grid_hw
if N != Ht * Wt:
raise ValueError(f"Token grid mismatch: N={N}, but Ht*Wt={Ht*Wt}")
return feats, (Ht, Wt)
# ============================================================
# Shared Teacher-Student Scrubber (model-agnostic)
# ============================================================
class TeacherStudentScrubber:
"""
Works for either UNet or Transformer denoisers via adapters.
You must supply:
- pipe.vae (frozen)
- pipe.scheduler
- denoiser module (pipe.unet for SDXL, or pipe.transformer/pipe.denoiser for transformer models)
- a cond_provider(batch_size) -> conditioning dict
- an adapter_factory(teacher_model) -> adapter instance
(important so feature hooks are registered on the teacher copy)
"""
def __init__(
self,
pipe,
denoiser,
adapter_factory,
cond_provider,
device="cuda",
num_clusters=5,
t_low=300,
t_high=700,
max_vecs_per_image=2000,
):
self.pipe = pipe
self.device = device
self.cond_provider = cond_provider
self.scheduler = pipe.scheduler
self.num_train_timesteps = getattr(self.scheduler.config, "num_train_timesteps", 1000)
# VAE frozen in float32
self.vae = pipe.vae.to(device, dtype=torch.float32).eval()
self.vae.requires_grad_(False)
self.vae_scale = float(getattr(self.vae.config, "scaling_factor", 0.18215))
# Student and Teacher denoisers
self.student = denoiser.to(device)
self.student.train()
self.teacher = copy.deepcopy(denoiser).to(device)
self.teacher.eval()
self.teacher.requires_grad_(False)
# Adapter must be built from TEACHER (so hooks attach to teacher modules)
self.adapter = adapter_factory(self.teacher)
# Prototypes
self.num_clusters = num_clusters
self.t_low = t_low
self.t_high = t_high
self.max_vecs_per_image = max_vecs_per_image
self.prototypes = None
@torch.no_grad()
def encode_images_to_latents(self, images: torch.Tensor) -> torch.Tensor:
images = images.to(self.device, dtype=torch.float32)
z0 = self.vae.encode(images).latent_dist.sample() * self.vae_scale
return z0
@torch.no_grad()
def fit_prototypes(self, forget_loader):
vectors = []
for batch in forget_loader:
z0 = self.encode_images_to_latents(batch) # float32
B = z0.shape[0]
# timestep in semantic band
t_low = max(0, min(self.t_low, self.num_train_timesteps - 1))
t_high = max(1, min(self.t_high, self.num_train_timesteps))
t = torch.randint(t_low, t_high, (B,), device=self.device).long()
eps = torch.randn_like(z0)
z_t = self.scheduler.add_noise(z0, eps, t)
cond = self.cond_provider(B)
# features from TEACHER, adapter returns [B, L, D]
feat_vec, _grid_hw = self.adapter.features(self.teacher, z_t.to(self.teacher.dtype), t, cond)
# flatten batch to [B*L, D]
flat = feat_vec.reshape(-1, feat_vec.shape[-1])
# subsample
max_total = self.max_vecs_per_image * B
if flat.shape[0] > max_total:
idx = torch.randperm(flat.shape[0], device=flat.device)[:max_total]
flat = flat[idx]
vectors.append(flat.float().cpu().numpy())
data = np.concatenate(vectors, axis=0)
kmeans = MiniBatchKMeans(
n_clusters=self.num_clusters,
batch_size=8192,
random_state=42,
n_init="auto",
)
kmeans.fit(data)
protos = torch.tensor(kmeans.cluster_centers_, device=self.device, dtype=self.teacher.dtype)
self.prototypes = F.normalize(protos, dim=1)
return self.prototypes
@torch.no_grad()
def compute_mask(self, z_t, t, cond, tau=0.60, alpha=15.0):
if self.prototypes is None:
raise RuntimeError("Call fit_prototypes() first.")
feat_vec, (Hloc, Wloc) = self.adapter.features(self.teacher, z_t.to(self.teacher.dtype), t, cond)
# feat_vec: [B, L, D] where L = Hloc*Wloc
B, L, D = feat_vec.shape
feat_norm = F.normalize(feat_vec, dim=-1) # [B, L, D]
sim = torch.matmul(feat_norm, self.prototypes.t()) # [B, L, K]
max_sim, _ = torch.max(sim, dim=-1) # [B, L]
mask_loc = torch.sigmoid((max_sim - tau) * alpha) # [B, L]
mask_loc = mask_loc.view(B, 1, Hloc, Wloc) # [B,1,Hloc,Wloc]
# upsample to latent resolution
mask = F.interpolate(mask_loc, size=z_t.shape[-2:], mode="nearest")
return mask
def training_step(
self,
batch_forget: torch.Tensor,
batch_retain: torch.Tensor = None,
tau=0.60,
alpha=15.0,
lambda_f=1.0,
lambda_b=1.0,
lambda_r=0.5,
):
# Forget batch latents
with torch.no_grad():
z0 = self.encode_images_to_latents(batch_forget).to(self.student.dtype)
B = z0.shape[0]
cond = self.cond_provider(B)
# noise and timestep
t = torch.randint(0, self.num_train_timesteps, (B,), device=self.device).long()
eps = torch.randn_like(z0)
z_t = self.scheduler.add_noise(z0, eps, t)
# Teacher detection and targets
with torch.no_grad():
M = self.compute_mask(z_t, t, cond, tau=tau, alpha=alpha)
# z0-blinding then re-noise with same eps and t
eta = torch.randn_like(z0)
z0_blind = (1.0 - M) * z0 + M * eta
z_t_blind = self.scheduler.add_noise(z0_blind, eps, t)
T_blind = self.adapter.forward(self.teacher, z_t_blind.to(self.teacher.dtype), t, cond)
T_orig = self.adapter.forward(self.teacher, z_t.to(self.teacher.dtype), t, cond)
# Student prediction
S = self.adapter.forward(self.student, z_t, t, cond)
loss_forget = F.mse_loss(S * M, T_blind * M)
loss_bg = F.mse_loss(S * (1.0 - M), T_orig * (1.0 - M))
total = lambda_f * loss_forget + lambda_b * loss_bg
# Retain anchor
if batch_retain is not None:
with torch.no_grad():
z0r = self.encode_images_to_latents(batch_retain).to(self.student.dtype)
Br = z0r.shape[0]
condr = self.cond_provider(Br)
tr = torch.randint(0, self.num_train_timesteps, (Br,), device=self.device).long()
epsr = torch.randn_like(z0r)
ztr = self.scheduler.add_noise(z0r, epsr, tr)
with torch.no_grad():
Tr = self.adapter.forward(self.teacher, ztr.to(self.teacher.dtype), tr, condr)
Sr = self.adapter.forward(self.student, ztr, tr, condr)
loss_retain = F.mse_loss(Sr, Tr)
total = total + lambda_r * loss_retain
logs = {
"loss_forget": float(loss_forget.detach().cpu()),
"loss_bg": float(loss_bg.detach().cpu()),
"loss_total": float(total.detach().cpu()),
}
return total, logs
How to instantiate for SDXL UNet (Diffusers)
# SDXL conditioning provider (null prompt, fixed canvas)
@torch.no_grad()
def make_sdxl_null_cond_provider(pipe, device="cuda"):
prompt_embeds, _, pooled, _ = pipe.encode_prompt(
prompt="", prompt_2="",
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=False,
)
original_size = (1024, 1024)
target_size = (1024, 1024)
crops_coords_top_left = (0, 0)
    add_time_ids = pipe._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        # Recent diffusers versions require this extra argument; drop it on older versions.
        text_encoder_projection_dim=pipe.text_encoder_2.config.projection_dim,
    ).to(device)
base = {
"encoder_hidden_states": prompt_embeds,
"added_cond_kwargs": {"text_embeds": pooled, "time_ids": add_time_ids},
}
def cond_provider(bsz: int):
return {
"encoder_hidden_states": base["encoder_hidden_states"].repeat(bsz, 1, 1),
"added_cond_kwargs": {
"text_embeds": base["added_cond_kwargs"]["text_embeds"].repeat(bsz, 1),
"time_ids": base["added_cond_kwargs"]["time_ids"].repeat(bsz, 1),
},
}
return cond_provider
# Usage:
# pipe = StableDiffusionXLPipeline.from_pretrained(...)
cond_provider = make_sdxl_null_cond_provider(pipe, device="cuda")
scrubber = TeacherStudentScrubber(
pipe=pipe,
denoiser=pipe.unet,
adapter_factory=lambda teacher_unet: UNetMidBlockAdapter(teacher_unet),
cond_provider=cond_provider,
device="cuda",
num_clusters=5,
)
How to instantiate for a Transformer denoiser
This depends on your pipeline. The structure below is the clean pattern:
# Example placeholders:
# denoiser = pipe.transformer # or pipe.denoiser
# token_grid_hw MUST match your model's image-token grid.
# Choose a mid block to hook after you know your model layout.
token_grid_hw = (32, 32) # example only, must match your model's image-token grid
def make_transformer_adapter(teacher_denoiser):
mid_block = teacher_denoiser.blocks[len(teacher_denoiser.blocks)//2]
return TransformerBlockAdapter(
transformer=teacher_denoiser,
block=mid_block,
token_grid_hw=token_grid_hw,
image_token_slice=None, # set if output includes non-image tokens
)
def cond_provider(bsz: int):
# Must match your transformer pipeline signature.
return {
"encoder_hidden_states": your_prompt_embeds.repeat(bsz, 1, 1),
# "added_cond_kwargs": {...} # if required by your model
}
scrubber = TeacherStudentScrubber(
pipe=pipe,
denoiser=denoiser,
adapter_factory=make_transformer_adapter,
cond_provider=cond_provider,
device="cuda",
num_clusters=5,
)
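Finally, a minimal driver-loop sketch, assuming forget_loader and retain_loader yield image tensors in the range the VAE expects; num_steps, the learning rate, and the logging cadence are illustrative, not prescribed.
import itertools
import torch
optimizer = torch.optim.AdamW(scrubber.student.parameters(), lr=1e-5)
# Phase 1: learn the concept fingerprint once
scrubber.fit_prototypes(forget_loader)
# Phase 2: unlearning loop
num_steps = 1000  # illustrative
retain_iter = itertools.cycle(retain_loader)
for step, batch_forget in enumerate(itertools.cycle(forget_loader)):
    if step >= num_steps:
        break
    loss, logs = scrubber.training_step(batch_forget, batch_retain=next(retain_iter))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, logs)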