提交 ad6caaf7 authored 作者: Frederic's avatar Frederic

make infer_shape don't return uint64 dtype as they are not supported.

上级 fa33b17f
...@@ -360,6 +360,10 @@ class RepeatOp(theano.Op): ...@@ -360,6 +360,10 @@ class RepeatOp(theano.Op):
repeats = node.inputs[1] repeats = node.inputs[1]
out_shape = list(i0_shapes) out_shape = list(i0_shapes)
#uint64 shape are not supported.
dtype = None
if repeats.dtype in ['uint8', 'uint16', 'uint32']:
dtype = 'int64'
if self.axis is None: if self.axis is None:
if repeats.ndim == 0: if repeats.ndim == 0:
if len(i0_shapes) == 0: if len(i0_shapes) == 0:
...@@ -370,12 +374,12 @@ class RepeatOp(theano.Op): ...@@ -370,12 +374,12 @@ class RepeatOp(theano.Op):
res = res * d res = res * d
out_shape = (res * repeats, ) out_shape = (res * repeats, )
else: else:
out_shape = [theano.tensor.sum(repeats)] out_shape = [theano.tensor.sum(repeats, dtype=dtype)]
else: else:
if repeats.ndim == 0: if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats out_shape[self.axis] = out_shape[self.axis] * repeats
else: else:
out_shape[self.axis] = theano.tensor.sum(repeats) out_shape[self.axis] = theano.tensor.sum(repeats, dtype=dtype)
return [out_shape] return [out_shape]
def __str__(self): def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论