提交 33a72bb9 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

unravel_index and ravel_multi_index always return int64.

上级 0cd76a61
...@@ -1207,7 +1207,7 @@ class UnravelIndex(gof.Op): ...@@ -1207,7 +1207,7 @@ class UnravelIndex(gof.Op):
return gof.Apply( return gof.Apply(
self, [indices, dims], self, [indices, dims],
[basic.TensorType(dtype=indices.dtype, broadcastable=(False,) * indices.ndim)() [basic.TensorType(dtype='int64', broadcastable=(False,) * indices.ndim)()
for i in xrange(self.ndim)]) for i in xrange(self.ndim)])
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
...@@ -1303,8 +1303,7 @@ class RavelMultiIndex(gof.Op): ...@@ -1303,8 +1303,7 @@ class RavelMultiIndex(gof.Op):
return gof.Apply( return gof.Apply(
self, multi_index + [dims], self, multi_index + [dims],
[basic.TensorType(dtype=multi_index[0].dtype, [basic.TensorType(dtype='int64', broadcastable=(False,) * multi_index[0].ndim)()])
broadcastable=(False,) * multi_index[0].ndim)()])
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [input_shapes[0]] return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论