Skip to content

Commit 2a71d35

Browse files
committed
Switch var_axis and std_axis to use FromPrimitive
The documentation for the `Zero` and `One` traits say only that they are the additive and multiplicative identities; it doesn't say anything about converting an integer to a float by adding `One::one()` to `Zero::zero()` repeatedly. Additionally, it's nice to panic early instead of waiting until after the sum has been calculated.
1 parent 5e31b7d commit 2a71d35

1 file changed

Lines changed: 23 additions & 16 deletions

File tree

src/numeric/impl_numeric.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use std::ops::{Add, Div};
10-
use libnum::{self, One, Zero, Float};
10+
use libnum::{self, One, Zero, Float, FromPrimitive};
1111
use itertools::free::enumerate;
1212

1313
use imp_prelude::*;
@@ -137,8 +137,11 @@ impl<A, S, D> ArrayBase<S, D>
137137
/// n i=1
138138
/// ```
139139
///
140-
/// **Panics** if `ddof` is less than zero or greater than the length of
141-
/// the axis or if `axis` is out of bounds.
140+
/// and `n` is the length of the axis.
141+
///
142+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
143+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
144+
/// numbers in the range `0..=n`.
142145
///
143146
/// # Example
144147
///
@@ -153,26 +156,27 @@ impl<A, S, D> ArrayBase<S, D>
153156
/// ```
154157
pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
155158
where
156-
A: Float,
159+
A: Float + FromPrimitive,
157160
D: RemoveAxis,
158161
{
159-
let mut count = A::zero();
162+
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
163+
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
164+
assert!(
165+
!(ddof < zero || ddof > n),
166+
"`ddof` must not be less than zero or greater than the length of \
167+
the axis",
168+
);
169+
let dof = n - ddof;
160170
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
161171
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
162-
for subview in self.axis_iter(axis) {
163-
count = count + A::one();
172+
for (i, subview) in self.axis_iter(axis).enumerate() {
173+
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
164174
azip!(mut mean, mut sum_sq, x (subview) in {
165175
let delta = x - *mean;
166176
*mean = *mean + delta / count;
167177
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
168178
});
169179
}
170-
assert!(
171-
!(ddof < A::zero() || ddof > count),
172-
"`ddof` must not be less than zero or greater than the length of \
173-
the axis",
174-
);
175-
let dof = count - ddof;
176180
sum_sq.mapv_into(|s| s / dof)
177181
}
178182

@@ -201,8 +205,11 @@ impl<A, S, D> ArrayBase<S, D>
201205
/// n i=1
202206
/// ```
203207
///
204-
/// **Panics** if `ddof` is less than zero or greater than the length of
205-
/// the axis or if `axis` is out of bounds.
208+
/// and `n` is the length of the axis.
209+
///
210+
/// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
211+
/// is out of bounds, or if `A::from_usize()` fails for any any of the
212+
/// numbers in the range `0..=n`.
206213
///
207214
/// # Example
208215
///
@@ -217,7 +224,7 @@ impl<A, S, D> ArrayBase<S, D>
217224
/// ```
218225
pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
219226
where
220-
A: Float,
227+
A: Float + FromPrimitive,
221228
D: RemoveAxis,
222229
{
223230
self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())

0 commit comments

Comments
 (0)