Skip to content

Commit 687a89c

Browse files
committed
update logic for avoiding div-0; add unit test
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
1 parent 23160ba commit 687a89c

2 files changed

Lines changed: 14 additions & 10 deletions

File tree

nemo_reinforcer/algorithms/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,9 @@ def wrapper(*args, **kwargs):
118118

119119
# need to surpress the masked tensor warnings from pytorch
120120
@surpress_user_warnings
121-
def masked_mean(values, mask, dim=None, check_all_zero_mask=True):
121+
def masked_mean(values, mask, dim=None):
122122
"""Masks values with mask, and computes the mean of the values using the masked values."""
123-
if dim is None:
124-
if check_all_zero_mask and mask.sum() == 0:
125-
return values.sum() * 0
126-
return values[mask.bool()].mean()
127-
return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
123+
return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)
128124

129125

130126
def set_seed(seed: int):

tests/unit/algorithms/test_loss_functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,18 @@ def test_masked_mean_all_zeros():
393393
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
394394
mask = torch.zeros_like(values)
395395

396-
# With check_zero_mask=True (default)
396+
# All zeros mask should return 0
397397
result = masked_mean(values, mask)
398-
assert torch.assert_allclose(result, torch.tensor(0.0))
398+
print(result)
399+
torch.testing.assert_allclose(result, torch.tensor(0.0))
399400

400401
# With check_zero_mask=False
401-
result = masked_mean(values, mask, check_all_zero_mask=False)
402-
assert torch.isnan(result) # Should be nan when mask is all zeros
402+
mask[0] = 1
403+
result = masked_mean(values, mask)
404+
torch.testing.assert_allclose(result, torch.tensor(1.0))
405+
406+
# Case 2: dim is not None
407+
values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
408+
mask = torch.zeros_like(values)
409+
result = masked_mean(values, mask, dim=1)
410+
torch.testing.assert_allclose(result, torch.tensor([0.0, 0.0]))

0 commit comments

Comments
 (0)