提交 788b52f6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Cast return value of infer_shapes to int64 in ARange.infer_shape

上级 3f6cb8f9
...@@ -501,7 +501,7 @@ class TensorType(Type): ...@@ -501,7 +501,7 @@ class TensorType(Type):
This read-only property is the preferred way to get the number of dimensions This read-only property is the preferred way to get the number of dimensions
of a `TensorType`. of a `TensorType`.
""" """
def make_variable(self, name = None): def make_variable(self, name = None):
...@@ -530,7 +530,7 @@ class TensorType(Type): ...@@ -530,7 +530,7 @@ class TensorType(Type):
bcast = named_broadcastable[b] bcast = named_broadcastable[b]
else: else:
if any(b): if any(b):
bcast = str(b) bcast = str(b)
else: else:
bcast = '%iD' % len(b) bcast = '%iD' % len(b)
return "TensorType(%s, %s)" % (str(self.dtype), bcast) return "TensorType(%s, %s)" % (str(self.dtype), bcast)
...@@ -3069,9 +3069,9 @@ class ARange(Op): ...@@ -3069,9 +3069,9 @@ class ARange(Op):
if is_constant_value(step, 1): if is_constant_value(step, 1):
if is_constant_value(start, 0): if is_constant_value(start, 0):
return [(stop,)] return [(cast(stop, 'int64'),)]
else: else:
return [((stop-start),)] return [(cast(stop-start, 'int64'),)]
else: else:
return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)] return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论