提交 4db19187 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add more type safety.

上级 168c502e
...@@ -2721,13 +2721,14 @@ class InversePermutation(Op): ...@@ -2721,13 +2721,14 @@ class InversePermutation(Op):
""" """
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, (x,), (outs,)): def perform(self, node, (x,), (outs,)):
if outs[0] is None or outs[0].shape != x.shape: if outs[0] is None or outs[0].shape != x.shape:
outs[0] = numpy.empty_like(x) outs[0] = numpy.empty_like(x)
for i in numpy.ndindex(x.shape[:-1]): for i in numpy.ndindex(x.shape[:-1]):
outs[0][i][x[i]] = numpy.arange(x.shape[0]) outs[0][i][x[i]] = numpy.arange(x.shape[-1], dtype=x.dtype)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [None] return [None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论