提交 304aad43 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update _check_rows_is_arange_len_labels after the introduction of ShapeOptimizer

上级 22b39728
...@@ -905,12 +905,19 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -905,12 +905,19 @@ def _check_rows_is_arange_len_labels(rows, labels):
return False return False
if getattr(step, 'data', None) != 1: # constant step will have data if getattr(step, 'data', None) != 1: # constant step will have data
return False return False
if stop.owner and isinstance(stop.owner.op, tensor.Subtensor): if not stop.owner:
return False
# Not sure if that case happens any more after the introduction
# of ShapeOptimizer
if isinstance(stop.owner.op, tensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]: if shape_subtensor.op.idx_list == [0]:
shape_var, = shape_subtensor.inputs shape_var, = shape_subtensor.inputs
if shape_var.owner and shape_var.owner.op == tensor._shape: if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels return shape_var.owner.inputs[0] is labels
else:
shape_of = stop.owner.env.shape_feature.shape_of
return shape_of[labels][0] is stop
@gof.local_optimizer([tensor._shape]) @gof.local_optimizer([tensor._shape])
def local_shape_lift_advanced_indexing_arange(node): def local_shape_lift_advanced_indexing_arange(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论