提交 be7831fc authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

extract_constant is now used

上级 1da11e43
......@@ -47,7 +47,8 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer)
from theano.gof import toolbox
from theano.tensor.basic import Alloc, get_scalar_constant_value, ShapeError, NotScalarConstantError
from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
extract_constant, NotScalarConstantError)
from six import StringIO
_logger = logging.getLogger('theano.tensor.opt')
......@@ -1699,7 +1700,7 @@ def local_useless_alloc(node):
output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape)):
if output_shape[i].get_scalar_constant_value() == 1:
if extract_constant(output_shape[i], only_process_constants=True) == 1:
num_dims_with_size_1_added_to_left += 1
else:
break
......
......@@ -3548,7 +3548,7 @@ class Test_local_useless_alloc(unittest.TestCase):
self._verify_stack_trace(f)
def test_useless_alloc_1_on_broadcastable(self):
def test_useless_alloc_with_shape_one(self):
alloc_lift = out2in(local_useless_alloc)
x = shared(self.rng.randn(2,))
y = shared(self.rng.randn())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论