提交 8ea11b10 authored 作者: abalkin's avatar abalkin

Started implementation of Take Op.

上级 b412823a
......@@ -6811,6 +6811,45 @@ class AdvancedIncSubtensor(Op):
*inputs[2:]).outputs
advanced_inc_subtensor = AdvancedIncSubtensor()
class Take(Op):
"""
Take elements from an array along an axis.
"""
def __init__(self, axis, mode):
self.axis = axis
self.mode = mode
def __eq__(self, other):
return (type(self) == type(other) and
self.axis == other.axis and
self.mode == other.mode)
def __hash__(self):
return hash((type(self), self.axis, self.mode))
def make_node(self, a, indices):
a = as_tensor_variable(a)
indices = as_tensor_variable(indices)
return gof.Apply(self, (a, indices), [a.type()])
def perform(self, node, inputs, outputs):
a, indices = inputs
out, = outputs
out[0] = a.take(indices, axis=self.axis, mode=self.mode)
def infer_shape(self, node, input_shapes):
a_shape, indices_shape = input_shapes
if self.axis is None:
shape = indices_shape
else:
shape = a_shape[:self.axis] + indices_shape + a_shape[self.axis+1:]
return [shape]
def take(a, indices, axis=None, mode='raise'):
return Take(axis, mode)(a, indices)
#########################
# Linalg : Dot
#########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论