Commit 28b0ab0
committed
Fix context parallel pad length collision across shared split dims
Store pre-pad sequence lengths keyed by input name instead of shard
dimension so tensors like hidden_states and encoder_hidden_states_mask
that share split_dim=1 no longer overwrite each other. The gather hook
looks up the pad length for hidden_states when trimming gathered output.1 parent 1be07ba commit 28b0ab0
2 files changed
Lines changed: 9 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
40 | | - | |
| 40 | + | |
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
| |||
244 | 244 | | |
245 | 245 | | |
246 | 246 | | |
247 | | - | |
| 247 | + | |
| 248 | + | |
248 | 249 | | |
249 | 250 | | |
250 | 251 | | |
| |||
282 | 283 | | |
283 | 284 | | |
284 | 285 | | |
285 | | - | |
286 | | - | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
287 | 289 | | |
288 | 290 | | |
289 | 291 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
420 | 420 | | |
421 | 421 | | |
422 | 422 | | |
423 | | - | |
| 423 | + | |
424 | 424 | | |
425 | 425 | | |
426 | 426 | | |
| |||
443 | 443 | | |
444 | 444 | | |
445 | 445 | | |
446 | | - | |
| 446 | + | |
447 | 447 | | |
448 | 448 | | |
449 | 449 | | |
| |||
454 | 454 | | |
455 | 455 | | |
456 | 456 | | |
457 | | - | |
| 457 | + | |
458 | 458 | | |
0 commit comments