提交 33a72bb9 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

unravel_index and ravel_multi_index always return int64.

上级 0cd76a61
......@@ -1207,7 +1207,7 @@ class UnravelIndex(gof.Op):
return gof.Apply(
self, [indices, dims],
[basic.TensorType(dtype=indices.dtype, broadcastable=(False,) * indices.ndim)()
[basic.TensorType(dtype='int64', broadcastable=(False,) * indices.ndim)()
for i in xrange(self.ndim)])
def infer_shape(self, node, input_shapes):
......@@ -1303,8 +1303,7 @@ class RavelMultiIndex(gof.Op):
return gof.Apply(
self, multi_index + [dims],
[basic.TensorType(dtype=multi_index[0].dtype,
broadcastable=(False,) * multi_index[0].ndim)()])
[basic.TensorType(dtype='int64', broadcastable=(False,) * multi_index[0].ndim)()])
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论