提交 4f05b2f8 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

useless alloc is now replaced with dimshuffle

上级 c7b07251
......@@ -47,7 +47,7 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from theano.tensor.basic import Alloc, get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO
_logger = logging.getLogger('theano.tensor.opt')
......@@ -1683,10 +1683,32 @@ def local_useless_alloc(node):
of the input. This is not needed.
"""
if node.op == T.alloc:
if node.inputs[0].type == node.outputs[0].type:
# We don't need to copy over any stack traces here
return [node.inputs[0]]
op = node.op
if not isinstance(op, Alloc):
return False
input = node.inputs[0]
output = node.outputs[0]
# Check if dtype and broadcast remain the same.
if input.type == output.type:
# We don't need to copy over any stack traces here
return [input]
# Check if alloc adds a broadcastable dimension with shape 1.
output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape)):
if output_shape[i].value == 1:
num_dims_with_size_1_added_to_left += 1
else:
break
if num_dims_with_size_1_added_to_left > 0:
new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
inner = op(*([input] + new_output_shape))
dimshuffle_new_order = (['x'] * num_dims_with_size_1_added_to_left +
range(len(new_output_shape)))
return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
# Don't register by default.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论