
NumPy issue #8708 has a sample implementation of take_along_axis that does what I need.
I'm not sure how efficient it is on large arrays, but it seems to work.
import numpy as np

def take_along_axis(arr, ind, axis):
    """
    ... here means a "pack" of dimensions, possibly empty

    arr: array_like of shape (A..., M, B...)
        source array
    ind: array_like of shape (A..., K..., B...)
        indices to take along each 1d slice of `arr`
    axis: int
        index of the axis with dimension M

    out: array_like of shape (A..., K..., B...)
        out[a..., k..., b...] = arr[a..., ind[a..., k..., b...], b...]
    """
    # accept any array_like, as the docstring promises
    arr = np.asarray(arr)
    ind = np.asarray(ind)
    # reject out-of-range axes in both directions, then normalize negatives
    if not -arr.ndim <= axis < arr.ndim:
        raise IndexError('axis out of range')
    if axis < 0:
        axis += arr.ndim

    ind_shape = (1,) * ind.ndim
    ins_ndim = ind.ndim - (arr.ndim - 1)  # number of inserted (K...) dimensions

    # for each axis of arr: the dimension of ind it lines up with;
    # None marks `axis` itself, whose values come from ind
    dest_dims = list(range(axis)) + [None] + list(range(axis + ins_ndim, ind.ndim))

    # could also call np.ix_ here with some dummy arguments, then throw those results away
    inds = []
    for dim, n in zip(dest_dims, arr.shape):
        if dim is None:
            inds.append(ind)
        else:
            # an arange that varies only along `dim`, so it broadcasts against ind
            ind_shape_dim = ind_shape[:dim] + (-1,) + ind_shape[dim+1:]
            inds.append(np.arange(n).reshape(ind_shape_dim))

    return arr[tuple(inds)]

which produces:
>>> A = np.array([[3,2,1],[4,0,6]])
>>> B = np.array([[3,1,4],[1,5,9]])
>>> i = A.argsort(axis=-1)
>>> take_along_axis(A,i,axis=-1)
array([[1, 2, 3],
       [0, 4, 6]])
>>> take_along_axis(B,i,axis=-1)
array([[4, 1, 3],
       [5, 1, 9]])
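To make the mechanism concrete: for a 2-D array and `axis=-1`, the index tuple the function builds is `(np.arange(2).reshape(-1, 1), i)`, so the same result can be reproduced with plain fancy indexing. NumPy 1.15 and later also ship a built-in `np.take_along_axis` for the case where the indices have the same number of dimensions as the source array. The sketch below reuses the arrays from the session above and checks both; the names `rows`, `manual` and `builtin` are illustrative only.

```python
import numpy as np

A = np.array([[3, 2, 1], [4, 0, 6]])
B = np.array([[3, 1, 4], [1, 5, 9]])
i = A.argsort(axis=-1)  # per-row permutation that sorts A

# For 2-D input and axis=-1, take_along_axis indexes with
# (np.arange(2).reshape(-1, 1), i): the column of row numbers
# broadcasts against i, which has shape (2, 3).
rows = np.arange(A.shape[0]).reshape(-1, 1)
manual = B[rows, i]

# NumPy >= 1.15 provides the same operation as a built-in;
# there, the indices must have the same ndim as the source array.
builtin = np.take_along_axis(B, i, axis=-1)

print(manual)
print(np.array_equal(manual, builtin))
```

Because the built-in requires matching ndim, the version from the issue may still be handy when `ind` carries more than one inserted (K...) dimension, as its docstring allows.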
Feel free to share; when reposting, please credit the source: 内存溢出