1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;
/// Stack arrays along the given axis.
///
/// The arrays are concatenated in order along `axis`. Every other axis must
/// have the same length in all inputs.
///
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`
/// (may be made more flexible in the future).<br>
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
/// or if the result is larger than is possible to represent.
///
/// ```
/// use ndarray::{arr2, Axis, stack};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack(Axis(0), &[a.view(), a.view()])
///     == Ok(arr2(&[[2., 2.],
///                  [3., 3.],
///                  [2., 2.],
///                  [3., 3.]]))
/// );
/// ```
pub fn stack<'a, A, D>(
    axis: Axis,
    arrays: &[ArrayView<'a, A, D>],
) -> Result<Array<A, D>, ShapeError>
where
    A: Copy,
    D: RemoveAxis,
{
    if arrays.is_empty() {
        return Err(from_kind(ErrorKind::Unsupported));
    }
    let mut res_dim = arrays[0].raw_dim();
    if axis.index() >= res_dim.ndim() {
        return Err(from_kind(ErrorKind::OutOfBounds));
    }

    // All inputs must agree on every axis except the one being stacked.
    let common_dim = res_dim.remove_axis(axis);
    if arrays
        .iter()
        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
    {
        return Err(from_kind(ErrorKind::IncompatibleShape));
    }

    // Length of the result along `axis`. This uses checked arithmetic so that
    // an unrepresentable total is reported as an error instead of silently
    // wrapping in release builds.
    let stacked_dim = arrays
        .iter()
        .try_fold(0usize, |acc, a| acc.checked_add(a.len_of(axis)))
        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
    res_dim.set_axis(axis, stacked_dim);

    // The total element count must not overflow and must fit in `isize`.
    // Checking here reports an error instead of letting the allocation
    // below panic.
    let size = res_dim
        .size_checked()
        .filter(|&n| n <= isize::MAX as usize)
        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;

    // Pre-fill the buffer with a real element rather than exposing
    // uninitialized memory via `set_len`, which is undefined behaviour even
    // for `Copy` types. If no input contains any element, the result has no
    // elements either, so an empty vector has the correct length.
    let fill = arrays.iter().find_map(|a| a.iter().next().cloned());
    let v = match fill {
        Some(x) => vec![x; size],
        None => Vec::new(),
    };
    let mut res = Array::from_shape_vec(res_dim, v)?;

    // Copy each input into its slot by peeling successive slices off the
    // front of the remaining output view.
    {
        let mut assign_view = res.view_mut();
        for array in arrays {
            let (mut front, rest) = assign_view.split_at(axis, array.len_of(axis));
            front.assign(array);
            assign_view = rest;
        }
    }
    Ok(res)
}
/// Stack arrays along the given axis.
///
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
/// argument `a`. A trailing comma after the last array is accepted.
///
/// [1]: fn.stack.html
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// extern crate ndarray;
///
/// use ndarray::{arr2, stack, Axis};
///
/// # fn main() {
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack![Axis(0), a, a]
///     == arr2(&[[2., 2.],
///               [3., 3.],
///               [2., 2.],
///               [3., 3.]])
/// );
/// # }
/// ```
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}