77// except according to those terms.
88
99use std:: ops:: { Add , Div } ;
10- use libnum:: { self , One , Zero , Float } ;
10+ use libnum:: { self , One , Zero , Float , FromPrimitive } ;
1111use itertools:: free:: enumerate;
1212
1313use imp_prelude:: * ;
@@ -137,8 +137,11 @@ impl<A, S, D> ArrayBase<S, D>
137137 /// n i=1
138138 /// ```
139139 ///
140- /// **Panics** if `ddof` is less than zero or greater than the length of
141- /// the axis or if `axis` is out of bounds.
140+ /// and `n` is the length of the axis.
141+ ///
142+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
143+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
144+ /// numbers in the range `0..=n`.
142145 ///
143146 /// # Example
144147 ///
@@ -153,26 +156,27 @@ impl<A, S, D> ArrayBase<S, D>
153156 /// ```
154157 pub fn var_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
155158 where
156- A : Float ,
159+ A : Float + FromPrimitive ,
157160 D : RemoveAxis ,
158161 {
159- let mut count = A :: zero ( ) ;
162+ let zero = A :: from_usize ( 0 ) . expect ( "Converting 0 to `A` must not fail." ) ;
163+ let n = A :: from_usize ( self . len_of ( axis) ) . expect ( "Converting length to `A` must not fail." ) ;
164+ assert ! (
165+ !( ddof < zero || ddof > n) ,
166+ "`ddof` must not be less than zero or greater than the length of \
167+ the axis",
168+ ) ;
169+ let dof = n - ddof;
160170 let mut mean = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
161171 let mut sum_sq = Array :: < A , _ > :: zeros ( self . dim . remove_axis ( axis) ) ;
162- for subview in self . axis_iter ( axis) {
163- count = count + A :: one ( ) ;
172+ for ( i , subview) in self . axis_iter ( axis) . enumerate ( ) {
173+ let count = A :: from_usize ( i + 1 ) . expect ( "Converting index to `A` must not fail." ) ;
164174 azip ! ( mut mean, mut sum_sq, x ( subview) in {
165175 let delta = x - * mean;
166176 * mean = * mean + delta / count;
167177 * sum_sq = ( x - * mean) . mul_add( delta, * sum_sq) ;
168178 } ) ;
169179 }
170- assert ! (
171- !( ddof < A :: zero( ) || ddof > count) ,
172- "`ddof` must not be less than zero or greater than the length of \
173- the axis",
174- ) ;
175- let dof = count - ddof;
176180 sum_sq. mapv_into ( |s| s / dof)
177181 }
178182
@@ -201,8 +205,11 @@ impl<A, S, D> ArrayBase<S, D>
201205 /// n i=1
202206 /// ```
203207 ///
204- /// **Panics** if `ddof` is less than zero or greater than the length of
205- /// the axis or if `axis` is out of bounds.
208+ /// and `n` is the length of the axis.
209+ ///
210+ /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
211+ /// is out of bounds, or if `A::from_usize()` fails for any any of the
212+ /// numbers in the range `0..=n`.
206213 ///
207214 /// # Example
208215 ///
@@ -217,7 +224,7 @@ impl<A, S, D> ArrayBase<S, D>
217224 /// ```
218225 pub fn std_axis ( & self , axis : Axis , ddof : A ) -> Array < A , D :: Smaller >
219226 where
220- A : Float ,
227+ A : Float + FromPrimitive ,
221228 D : RemoveAxis ,
222229 {
223230 self . var_axis ( axis, ddof) . mapv_into ( |x| x. sqrt ( ) )
0 commit comments