Fix modeling_smoldlm.py: download-first quickstart + MPS float64 fix
Browse files
- modeling_smoldlm.py +4 -2
modeling_smoldlm.py
CHANGED
|
@@ -319,8 +319,10 @@ def _add_gumbel_noise(logits, temperature):
|
|
| 319 |
"""Gumbel-max sampling for stochastic token selection."""
|
| 320 |
if temperature == 0:
|
| 321 |
return logits
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
| 324 |
gumbel_noise = (-torch.log(noise.clamp(min=1e-20))) ** temperature
|
| 325 |
return logits.exp() / gumbel_noise
|
| 326 |
|
|
|
|
| 319 |
"""Gumbel-max sampling for stochastic token selection."""
|
| 320 |
if temperature == 0:
|
| 321 |
return logits
|
| 322 |
+
# float64 for precision on CUDA/CPU; the MPS backend does not support float64, so fall back to float32 there
|
| 323 |
+
dtype = torch.float32 if logits.device.type == "mps" else torch.float64
|
| 324 |
+
logits = logits.to(dtype)
|
| 325 |
+
noise = torch.rand_like(logits, dtype=dtype)
|
| 326 |
gumbel_noise = (-torch.log(noise.clamp(min=1e-20))) ** temperature
|
| 327 |
return logits.exp() / gumbel_noise
|
| 328 |
|