提交 04603e7a authored 作者: Frederic Bastien's avatar Frederic Bastien

added infer_shape to ARange op.

上级 a5641dce
......@@ -3050,6 +3050,12 @@ class ARange(Op):
outputs = [tensor(self.dtype, (False,))]
return Apply(self, inputs, outputs)
def infer_shape(self, node, i_shapes):
start = as_tensor_variable(node.inputs[0])
stop = as_tensor_variable(node.inputs[1])
step = as_tensor_variable(node.inputs[2])
return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)]
def perform(self, node, (start, stop, step), (out,)):
start = start.item()
stop = stop.item()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论