提交 39b45c6a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Optimize special cases of ARange.infer_shapes, when start=0 or step=1.

上级 ad7937d2
...@@ -3051,10 +3051,29 @@ class ARange(Op): ...@@ -3051,10 +3051,29 @@ class ARange(Op):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def infer_shape(self, node, i_shapes): def infer_shape(self, node, i_shapes):
start = as_tensor_variable(node.inputs[0]) start, stop, step = node.inputs
stop = as_tensor_variable(node.inputs[1]) def is_constant_value(var, value):
step = as_tensor_variable(node.inputs[2]) if numpy.all(var == value):
return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)] return True
if isinstance(var, gof.Constant):
return numpy.all(var.data == value)
if var.owner:
if var.owner.op == T.alloc:
return is_constant_value(var.owner.inputs[0], value)
if isinstance(var.owner.op, DimShuffle):
return is_constant_value(var.owner.inputs[0], value)
if var.owner.op == T.fill:
shape, in_var = var.owner.inputs
return is_constant_value(in_var, value)
return False
if is_constant_value(step, 1):
if is_constant_value(start, 0):
return [(stop,)]
else:
return [((stop-start),)]
else:
return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)]
def perform(self, node, (start, stop, step), (out,)): def perform(self, node, (start, stop, step), (out,)):
start = start.item() start = start.item()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论