提交 61011501 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added choose fonction.

上级 647921d7
...@@ -5092,3 +5092,43 @@ def swapaxes(y, axis1, axis2): ...@@ -5092,3 +5092,43 @@ def swapaxes(y, axis1, axis2):
li = range(0, ndim) li = range(0, ndim)
li[axis1], li[axis2] = li[axis2], li[axis1] li[axis1], li[axis2] = li[axis2], li[axis1]
return y.dimshuffle(li) return y.dimshuffle(li)
class choose(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
return Apply(self, [x, y], [x.type()])
def perform(self, node, inputs, (z,)):
z[0] = numpy.choose(inputs[0], inputs[1])
def infer_shape(self, nodes, shapes):
return [shapes[0]]
def __str__(self):
return self.__class__.__name__
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论