We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 23160ba commit 687a89cCopy full SHA for 687a89c
2 files changed
nemo_reinforcer/algorithms/utils.py
@@ -118,13 +118,9 @@ def wrapper(*args, **kwargs):
118
119
# need to surpress the masked tensor warnings from pytorch
120
@surpress_user_warnings
121
-def masked_mean(values, mask, dim=None, check_all_zero_mask=True):
+def masked_mean(values, mask, dim=None):
122
"""Masks values with mask, and computes the mean of the values using the masked values."""
123
- if dim is None:
124
- if check_all_zero_mask and mask.sum() == 0:
125
- return values.sum() * 0
126
- return values[mask.bool()].mean()
127
- return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
+ return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)
128
129
130
def set_seed(seed: int):
tests/unit/algorithms/test_loss_functions.py
@@ -393,10 +393,18 @@ def test_masked_mean_all_zeros():
393
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
394
mask = torch.zeros_like(values)
395
396
- # With check_zero_mask=True (default)
+ # All zeros mask should return 0
397
result = masked_mean(values, mask)
398
- assert torch.assert_allclose(result, torch.tensor(0.0))
+ print(result)
399
+ torch.testing.assert_allclose(result, torch.tensor(0.0))
400
401
# With check_zero_mask=False
- result = masked_mean(values, mask, check_all_zero_mask=False)
402
- assert torch.isnan(result) # Should be nan when mask is all zeros
+ mask[0] = 1
403
+ result = masked_mean(values, mask)
404
+ torch.testing.assert_allclose(result, torch.tensor(1.0))
405
+
406
+ # Case 2: dim is not None
407
+ values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
408
+ mask = torch.zeros_like(values)
409
+ result = masked_mean(values, mask, dim=1)
410
+ torch.testing.assert_allclose(result, torch.tensor([0.0, 0.0]))
0 commit comments