diff --git a/README-quick-start.md b/README-quick-start.md index ad13acc72..84e968744 100644 --- a/README-quick-start.md +++ b/README-quick-start.md @@ -91,7 +91,7 @@ fn main() { use ndarray::prelude::*; use ndarray::{Array, Ix3}; fn main() { - let a = Array::::linspace(0., 5., 11); + let a = Array::::linspace(0.0..=5.0, 11); println!("{:?}", a); } ``` diff --git a/benches/bench1.rs b/benches/bench1.rs index ea527cd35..3b5405329 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -984,7 +984,7 @@ const MEAN_SUM_N: usize = 127; fn range_mat(m: Ix, n: Ix) -> Array2 { assert!(m * n != 0); - Array::linspace(0., (m * n - 1) as f32, m * n) + Array::linspace(0.0..=(m * n - 1) as f32, m * n) .into_shape_with_order((m, n)) .unwrap() } diff --git a/benches/construct.rs b/benches/construct.rs index 71a4fb905..958eaa3b6 100644 --- a/benches/construct.rs +++ b/benches/construct.rs @@ -21,7 +21,7 @@ fn zeros_f64(bench: &mut Bencher) #[bench] fn map_regular(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 128) + let a = Array::linspace(0.0..=127.0, 128) .into_shape_with_order((8, 16)) .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); @@ -31,7 +31,7 @@ fn map_regular(bench: &mut test::Bencher) #[bench] fn map_stride(bench: &mut test::Bencher) { - let a = Array::linspace(0., 127., 256) + let a = Array::linspace(0.0..=127.0, 256) .into_shape_with_order((8, 32)) .unwrap(); let av = a.slice(s![.., ..;2]); diff --git a/benches/higher-order.rs b/benches/higher-order.rs index 5eb009566..6356687fb 100644 --- a/benches/higher-order.rs +++ b/benches/higher-order.rs @@ -14,7 +14,7 @@ const Y: usize = 16; #[bench] fn map_regular(bench: &mut Bencher) { - let a = Array::linspace(0., 127., N) + let a = Array::linspace(0.0..=127.0, N) .into_shape_with_order((X, Y)) .unwrap(); bench.iter(|| a.map(|&x| 2. * x)); @@ -29,7 +29,7 @@ pub fn double_array(mut a: ArrayViewMut2<'_, f64>) #[bench] fn map_stride_double_f64(bench: &mut Bencher) { - let mut a = Array::linspace(0., 127., N * 2) + let mut a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); let mut av = a.slice_mut(s![.., ..;2]); @@ -42,7 +42,7 @@ fn map_stride_double_f64(bench: &mut Bencher) #[bench] fn map_stride_f64(bench: &mut Bencher) { - let a = Array::linspace(0., 127., N * 2) + let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); let av = a.slice(s![.., ..;2]); @@ -53,7 +53,7 @@ fn map_stride_f64(bench: &mut Bencher) #[bench] fn map_stride_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 127., N * 2) + let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); let b = a.mapv(|x| x as u32); @@ -65,7 +65,7 @@ fn map_stride_u32(bench: &mut Bencher) #[bench] fn fold_axis(bench: &mut Bencher) { - let a = Array::linspace(0., 127., N * 2) + let a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); bench.iter(|| a.fold_axis(Axis(0), 0., |&acc, &elt| acc + elt)); diff --git a/benches/iter.rs b/benches/iter.rs index bc483c8c2..0e18f1230 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -47,7 +47,7 @@ fn iter_sum_2d_transpose(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); @@ -58,7 +58,7 @@ fn iter_filter_sum_2d_u32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a * 100.; @@ -69,7 +69,7 @@ fn iter_filter_sum_2d_f32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a.mapv(|x| (x * 100.) as u32); @@ -81,7 +81,7 @@ fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) #[bench] fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 256) + let a = Array::linspace(0.0..=1.0, 256) .into_shape_with_order((16, 16)) .unwrap(); let b = a * 100.; @@ -93,7 +93,7 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) #[bench] fn iter_rev_step_by_contiguous(bench: &mut Bencher) { - let a = Array::linspace(0., 1., 512); + let a = Array::linspace(0.0..=1.0, 512); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { black_box(x); @@ -105,7 +105,7 @@ fn iter_rev_step_by_contiguous(bench: &mut Bencher) #[bench] fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { - let mut a = Array::linspace(0., 1., 1024); + let mut a = Array::linspace(0.0..=1.0, 1024); a.slice_axis_inplace(Axis(0), Slice::new(0, None, 2)); bench.iter(|| { a.iter().rev().step_by(2).for_each(|x| { diff --git a/benches/numeric.rs b/benches/numeric.rs index ceb57fbd7..5dcde52d4 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -13,7 +13,7 @@ const Y: usize = 16; #[bench] fn clip(bench: &mut Bencher) { - let mut a = Array::linspace(0., 127., N * 2) + let mut a = Array::linspace(0.0..=127.0, N * 2) .into_shape_with_order([X, Y * 2]) .unwrap(); let min = 2.; diff --git a/examples/sort-axis.rs b/examples/sort-axis.rs index 4da3a64d5..112abfc77 100644 --- a/examples/sort-axis.rs +++ b/examples/sort-axis.rs @@ -169,7 +169,7 @@ where D: Dimension #[cfg(feature = "std")] fn main() { - let a = Array::linspace(0., 63., 64) + let a = Array::linspace(0.0..=63.0, 64) .into_shape_with_order((8, 8)) .unwrap(); let strings = a.map(|x| x.to_string()); diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index bb6b7ae83..a9400211d 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -195,8 +195,8 @@ //! ------|-----------|------ //! `np.array([[1.,2.,3.], [4.,5.,6.]])` | [`array![[1.,2.,3.], [4.,5.,6.]]`][array!] or [`arr2(&[[1.,2.,3.], [4.,5.,6.]])`][arr2()] | 2×3 floating-point array literal //! `np.arange(0., 10., 0.5)` or `np.r_[:10.:0.5]` | [`Array::range(0., 10., 0.5)`][::range()] | create a 1-D array with values `0.`, `0.5`, …, `9.5` -//! `np.linspace(0., 10., 11)` or `np.r_[:10.:11j]` | [`Array::linspace(0., 10., 11)`][::linspace()] | create a 1-D array with 11 elements with values `0.`, …, `10.` -//! `np.logspace(2.0, 3.0, num=4, base=10.0)` | [`Array::logspace(10.0, 2.0, 3.0, 4)`][::logspace()] | create a 1-D array with 4 elements with values `100.`, `215.4`, `464.1`, `1000.` +//! `np.linspace(0., 10., 11)` or `np.r_[:10.:11j]` | [`Array::linspace(0.0..=10.0, 11)`][::linspace()] | create a 1-D array with 11 elements with values `0.`, …, `10.` +//! `np.logspace(2.0, 3.0, num=4, base=10.0)` | [`Array::logspace(10.0, 2.0..=3.0, 4)`][::logspace()] | create a 1-D array with 4 elements with values `100.`, `215.4`, `464.1`, `1000.` //! `np.geomspace(1., 1000., num=4)` | [`Array::geomspace(1e0, 1e3, 4)`][::geomspace()] | create a 1-D array with 4 elements with values `1.`, `10.`, `100.`, `1000.` //! `np.ones((3, 4, 5))` | [`Array::ones((3, 4, 5))`][::ones()] | create a 3×4×5 array filled with ones (inferring the element type) //! `np.zeros((3, 4, 5))` | [`Array::zeros((3, 4, 5))`][::zeros()] | create a 3×4×5 array filled with zeros (inferring the element type) diff --git a/src/finite_bounds.rs b/src/finite_bounds.rs new file mode 100644 index 000000000..565fe2bcb --- /dev/null +++ b/src/finite_bounds.rs @@ -0,0 +1,42 @@ +use num_traits::Float; + +pub enum Bound +{ + Included(F), + Excluded(F), +} + +/// A version of std::ops::RangeBounds that only implements a..b and a..=b ranges. +pub trait FiniteBounds +{ + fn start_bound(&self) -> F; + fn end_bound(&self) -> Bound; +} + +impl FiniteBounds for std::ops::Range +where F: Float +{ + fn start_bound(&self) -> F + { + self.start + } + + fn end_bound(&self) -> Bound + { + Bound::Excluded(self.end) + } +} + +impl FiniteBounds for std::ops::RangeInclusive +where F: Float +{ + fn start_bound(&self) -> F + { + *self.start() + } + + fn end_bound(&self) -> Bound + { + Bound::Included(*self.end()) + } +} diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index ba01e2ca3..7f71cca5b 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -58,10 +58,7 @@ where S: DataOwned pub fn from_vec(v: Vec) -> Self { if mem::size_of::() == 0 { - assert!( - v.len() <= isize::MAX as usize, - "Length must fit in `isize`.", - ); + assert!(v.len() <= isize::MAX as usize, "Length must fit in `isize`.",); } unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) } } @@ -95,14 +92,16 @@ where S: DataOwned /// ```rust /// use ndarray::{Array, arr1}; /// - /// let array = Array::linspace(0., 1., 5); + /// let array = Array::linspace(0.0..=1.0, 5); /// assert!(array == arr1(&[0.0, 0.25, 0.5, 0.75, 1.0])) /// ``` #[cfg(feature = "std")] - pub fn linspace(start: A, end: A, n: usize) -> Self - where A: Float + pub fn linspace(range: R, n: usize) -> Self + where + R: crate::finite_bounds::FiniteBounds, + A: Float, { - Self::from(to_vec(linspace::linspace(start, end, n))) + Self::from(to_vec(linspace::linspace(range, n))) } /// Create a one-dimensional array with elements from `start` to `end` @@ -137,18 +136,20 @@ where S: DataOwned /// use approx::assert_abs_diff_eq; /// use ndarray::{Array, arr1}; /// - /// let array = Array::logspace(10.0, 0.0, 3.0, 4); + /// let array = Array::logspace(10.0, 0.0..=3.0, 4); /// assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12); /// - /// let array = Array::logspace(-10.0, 3.0, 0.0, 4); + /// let array = Array::logspace(-10.0, 3.0..=0.0, 4); /// assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12); /// # } /// ``` #[cfg(feature = "std")] - pub fn logspace(base: A, start: A, end: A, n: usize) -> Self - where A: Float + pub fn logspace(base: A, range: R, n: usize) -> Self + where + R: crate::finite_bounds::FiniteBounds, + A: Float, { - Self::from(to_vec(logspace::logspace(base, start, end, n))) + Self::from(to_vec(logspace::logspace(base, range, n))) } /// Create a one-dimensional array with `n` geometrically spaced elements diff --git a/src/lib.rs b/src/lib.rs index 41e5ca350..970c3f126 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -199,6 +199,8 @@ mod indexes; mod iterators; mod layout; mod linalg_traits; +#[cfg(feature = "std")] +mod finite_bounds; mod linspace; #[cfg(feature = "std")] pub use crate::linspace::{linspace, range, Linspace}; diff --git a/src/linspace.rs b/src/linspace.rs index 411c480db..ff52bf0c1 100644 --- a/src/linspace.rs +++ b/src/linspace.rs @@ -6,6 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![cfg(feature = "std")] + +use crate::finite_bounds::{Bound, FiniteBounds}; + use num_traits::Float; /// An iterator of a sequence of evenly spaced floats. @@ -71,17 +74,24 @@ impl ExactSizeIterator for Linspace where Linspace: Iterator {} /// The iterator element type is `F`, where `F` must implement [`Float`], e.g. /// [`f32`] or [`f64`]. /// -/// **Panics** if converting `n - 1` to type `F` fails. +/// **Panics** if converting `n` to type `F` fails. #[inline] -pub fn linspace(a: F, b: F, n: usize) -> Linspace -where F: Float +pub fn linspace(range: R, n: usize) -> Linspace +where + R: FiniteBounds, + F: Float, { - let step = if n > 1 { - let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); + let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) { + (a, Bound::Included(b)) => (a, b, F::from(n - 1).expect("Converting number of steps to `A` must not fail.")), + (a, Bound::Excluded(b)) => (a, b, F::from(n).expect("Converting number of steps to `A` must not fail.")), + }; + + let step = if num_steps > F::zero() { (b - a) / num_steps } else { F::zero() }; + Linspace { start: a, step, diff --git a/src/logspace.rs b/src/logspace.rs index 463012018..dd1b7ae19 100644 --- a/src/logspace.rs +++ b/src/logspace.rs @@ -6,6 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![cfg(feature = "std")] + +use crate::finite_bounds::{Bound, FiniteBounds}; + use num_traits::Float; /// An iterator of a sequence of logarithmically spaced number. @@ -79,15 +82,22 @@ impl ExactSizeIterator for Logspace where Logspace: Iterator {} /// /// **Panics** if converting `n - 1` to type `F` fails. #[inline] -pub fn logspace(base: F, a: F, b: F, n: usize) -> Logspace -where F: Float +pub fn logspace(base: F, range: R, n: usize) -> Logspace +where + R: FiniteBounds, + F: Float, { - let step = if n > 1 { - let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail."); + let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) { + (a, Bound::Included(b)) => (a, b, F::from(n - 1).expect("Converting number of steps to `A` must not fail.")), + (a, Bound::Excluded(b)) => (a, b, F::from(n).expect("Converting number of steps to `A` must not fail.")), + }; + + let step = if num_steps > F::zero() { (b - a) / num_steps } else { F::zero() }; + Logspace { sign: base.signum(), base: base.abs(), @@ -110,23 +120,23 @@ mod tests use crate::{arr1, Array1}; use approx::assert_abs_diff_eq; - let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect(); + let array: Array1<_> = logspace(10.0, 0.0..=3.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12); - let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect(); + let array: Array1<_> = logspace(10.0, 3.0..=0.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12); - let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect(); + let array: Array1<_> = logspace(-10.0, 3.0..=0.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12); - let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect(); + let array: Array1<_> = logspace(-10.0, 0.0..=3.0, 4).collect(); assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12); } #[test] fn iter_forward() { - let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4))); @@ -142,7 +152,7 @@ mod tests #[test] fn iter_backward() { - let mut iter = logspace(10.0f64, 0.0, 3.0, 4); + let mut iter = logspace(10.0f64, 0.0..=3.0, 4); assert!(iter.size_hint() == (4, Some(4))); diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 2eef69307..3ac0d4b04 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -65,7 +65,7 @@ //! use ndarray::Axis; //! use ndarray::parallel::prelude::*; //! -//! let a = Array::linspace(0., 63., 64).into_shape_with_order((4, 16)).unwrap(); +//! let a = Array::linspace(0.0..=63.0, 64).into_shape_with_order((4, 16)).unwrap(); //! let mut sums = Vec::new(); //! a.axis_iter(Axis(0)) //! .into_par_iter() @@ -84,7 +84,7 @@ //! use ndarray::Axis; //! use ndarray::parallel::prelude::*; //! -//! let a = Array::linspace(0., 63., 64).into_shape_with_order((4, 16)).unwrap(); +//! let a = Array::linspace(0.0..=63.0, 64).into_shape_with_order((4, 16)).unwrap(); //! let mut shapes = Vec::new(); //! a.axis_chunks_iter(Axis(0), 3) //! .into_par_iter() diff --git a/tests/par_azip.rs b/tests/par_azip.rs index 41011d495..7dd233e5e 100644 --- a/tests/par_azip.rs +++ b/tests/par_azip.rs @@ -41,7 +41,7 @@ fn test_par_azip3() *a += b / 10.; *c = a.sin(); }); - let res = Array::linspace(0., 3.1, 32).mapv_into(f32::sin); + let res = Array::linspace(0.0..=3.1, 32).mapv_into(f32::sin); assert_abs_diff_eq!(res, ArrayView::from(&c), epsilon = 1e-4); } diff --git a/tests/par_rayon.rs b/tests/par_rayon.rs index 13669763f..1b6b2b794 100644 --- a/tests/par_rayon.rs +++ b/tests/par_rayon.rs @@ -26,7 +26,7 @@ fn test_axis_iter() fn test_axis_iter_mut() { use approx::assert_abs_diff_eq; - let mut a = Array::linspace(0., 1.0f64, M * N) + let mut a = Array::linspace(0.0..=1.0f64, M * N) .into_shape_with_order((M, N)) .unwrap(); let b = a.mapv(|x| x.exp()); @@ -82,7 +82,7 @@ fn test_axis_chunks_iter() fn test_axis_chunks_iter_mut() { use approx::assert_abs_diff_eq; - let mut a = Array::linspace(0., 1.0f64, M * N) + let mut a = Array::linspace(0.0..=1.0f64, M * N) .into_shape_with_order((M, N)) .unwrap(); let b = a.mapv(|x| x.exp());