提交 aa7a7fdc authored 作者: abalkin's avatar abalkin

Reuse advanced_subtensor1 in all cases.

上级 11603832
......@@ -6814,47 +6814,6 @@ class AdvancedIncSubtensor(Op):
*inputs[2:]).outputs
advanced_inc_subtensor = AdvancedIncSubtensor()
class Take(Op):
"""
Take elements from an array along an axis.
"""
def __init__(self, axis, mode):
self.axis = axis
self.mode = mode
def __eq__(self, other):
return (type(self) == type(other) and
self.axis == other.axis and
self.mode == other.mode)
def __hash__(self):
return hash((type(self), self.axis, self.mode))
def make_node(self, a, indices):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
if self.axis is None:
broadcastable = [False]
else:
broadcastable = (a.broadcastable[:self.axis] +
indices.broadcastable +
a.broadcastable[self.axis+1:])
return gof.Apply(self, (a, indices),
[TensorType(a.dtype, broadcastable)()])
def perform(self, node, inputs, outputs):
a, indices = inputs
out, = outputs
out[0] = a.take(indices, axis=self.axis, mode=self.mode)
def infer_shape(self, node, input_shapes):
a_shape, indices_shape = input_shapes
if self.axis is None:
shape = indices_shape
else:
shape = a_shape[:self.axis] + indices_shape + a_shape[self.axis+1:]
return [shape]
def take(a, indices, axis=None, mode='raise'):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
......@@ -6877,7 +6836,13 @@ def take(a, indices, axis=None, mode='raise'):
shuffle[axis] = 0
return advanced_subtensor1(
a.dimshuffle(shuffle), indices).dimshuffle(shuffle)
return Take(axis, mode)(a, indices)
if axis is None:
shape = indices.shape
ndim = indices.ndim
else:
shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
ndim = a.ndim + indices.ndim - 1
return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
#########################
# Linalg : Dot
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论