Commit 5495fb4
committed
Fix Flax LMS scheduler sigma mismatch and derivatives buffer OOB
Remove the erroneous +1e-5 sigma offset in step() so it matches
scale_model_input() and the PyTorch LMSDiscreteScheduler reference.
Cap the multistep fori_loop bound by the pre-allocated derivatives
buffer size to prevent out-of-bounds indexing when order exceeds
max_order from set_timesteps.1 parent 4eeb570 commit 5495fb4
1 file changed
Lines changed: 3 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
271 | 271 | | |
272 | 272 | | |
273 | 273 | | |
274 | | - | |
| 274 | + | |
275 | 275 | | |
276 | 276 | | |
277 | 277 | | |
| |||
291 | 291 | | |
292 | 292 | | |
293 | 293 | | |
| 294 | + | |
294 | 295 | | |
295 | 296 | | |
296 | | - | |
| 297 | + | |
297 | 298 | | |
298 | 299 | | |
299 | 300 | | |
| |||
0 commit comments