提交 d9fe1b74 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Forgot the if condition for the NoneType check

上级 a384c80d
...@@ -157,17 +157,19 @@ def local_dimshuffle_alloc(node): ...@@ -157,17 +157,19 @@ def local_dimshuffle_alloc(node):
dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2) dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
""" """
if isinstance(node.op, T.DimShuffle) and isinstance(node.inputs[0].owner.op, T.Alloc): if isinstance(node.op, T.DimShuffle) and node.inputs[0].owner:
# check if it only adds dimension to the left input_ = node.inputs[0]
new_order = node.op.new_order if isinstance(input_.owner.op, T.Alloc):
expected_new_order = ('x',) * (len(new_order) - node.inputs[0].ndim) + \ # check if it only adds dimension to the left
tuple(range(node.inputs[0].ndim)) new_order = node.op.new_order
if new_order != expected_new_order: expected_new_order = ('x',) * (len(new_order) - input_.ndim) + \
return False tuple(range(input_.ndim))
if new_order != expected_new_order:
# count numbers of 'x' return False
nb_new_dims = len(new_order) - node.inputs[0].ndim
new_shape_input = (1,) * nb_new_dims + tuple(node.inputs[0].owner.inputs[1:]) # count numbers of 'x'
nb_new_dims = len(new_order) - input_.ndim
return [T.alloc(node.inputs[0].owner.inputs[0], *new_shape_input)] new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
return [T.alloc(input_.owner.inputs[0], *new_shape_input)]
return False return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论