提交 59a99da3 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added mode to choose function.

上级 61011501
...@@ -5093,14 +5093,22 @@ def swapaxes(y, axis1, axis2): ...@@ -5093,14 +5093,22 @@ def swapaxes(y, axis1, axis2):
li[axis1], li[axis2] = li[axis2], li[axis1] li[axis1], li[axis2] = li[axis2], li[axis1]
return y.dimshuffle(li) return y.dimshuffle(li)
def choose(x, y, mode='raise'):
return Choose(mode)(x, y)
class choose(Op): class Choose(Op):
def __eq__(self, other): def __init__(self, mode):
return type(self) == type(other) self.mode = mode
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash((type(self), self.props()))
def __eq__(self, other):
return (type(self) == type(other) and self.props() == other.props())
def props(self):
return self.mode
def make_node(self, x, y): def make_node(self, x, y):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -5109,7 +5117,7 @@ class choose(Op): ...@@ -5109,7 +5117,7 @@ class choose(Op):
return Apply(self, [x, y], [x.type()]) return Apply(self, [x, y], [x.type()])
def perform(self, node, inputs, (z,)): def perform(self, node, inputs, (z,)):
z[0] = numpy.choose(inputs[0], inputs[1]) z[0] = numpy.choose(inputs[0], inputs[1], self.mode)
def infer_shape(self, nodes, shapes): def infer_shape(self, nodes, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论