HoangHa commited on
Commit
ea93d7f
·
verified ·
1 Parent(s): 06b57f2

Fix modeling_smoldlm.py: download-first quickstart + MPS float64 fix

Browse files
Files changed (1) hide show
  1. modeling_smoldlm.py +4 -2
modeling_smoldlm.py CHANGED
@@ -319,8 +319,10 @@ def _add_gumbel_noise(logits, temperature):
319
  """Gumbel-max sampling for stochastic token selection."""
320
  if temperature == 0:
321
  return logits
322
- logits = logits.to(torch.float64)
323
- noise = torch.rand_like(logits, dtype=torch.float64)
 
 
324
  gumbel_noise = (-torch.log(noise.clamp(min=1e-20))) ** temperature
325
  return logits.exp() / gumbel_noise
326
 
 
319
  """Gumbel-max sampling for stochastic token selection."""
320
  if temperature == 0:
321
  return logits
322
+ # float64 for precision on CUDA/CPU; MPS only supports float32
323
+ dtype = torch.float32 if logits.device.type == "mps" else torch.float64
324
+ logits = logits.to(dtype)
325
+ noise = torch.rand_like(logits, dtype=dtype)
326
  gumbel_noise = (-torch.log(noise.clamp(min=1e-20))) ** temperature
327
  return logits.exp() / gumbel_noise
328