提交 be7831fc authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

extract_constant is now used

上级 1da11e43
...@@ -47,7 +47,8 @@ from theano.tensor.type import (values_eq_approx_remove_inf, ...@@ -47,7 +47,8 @@ from theano.tensor.type import (values_eq_approx_remove_inf,
from theano.gof.opt import (Optimizer, pre_constant_merge, from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof import toolbox from theano.gof import toolbox
from theano.tensor.basic import Alloc, get_scalar_constant_value, ShapeError, NotScalarConstantError from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
extract_constant, NotScalarConstantError)
from six import StringIO from six import StringIO
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
...@@ -1699,7 +1700,7 @@ def local_useless_alloc(node): ...@@ -1699,7 +1700,7 @@ def local_useless_alloc(node):
output_shape = node.inputs[1:] output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0 num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape)): for i in range(len(output_shape)):
if output_shape[i].get_scalar_constant_value() == 1: if extract_constant(output_shape[i], only_process_constants=True) == 1:
num_dims_with_size_1_added_to_left += 1 num_dims_with_size_1_added_to_left += 1
else: else:
break break
......
...@@ -3548,7 +3548,7 @@ class Test_local_useless_alloc(unittest.TestCase): ...@@ -3548,7 +3548,7 @@ class Test_local_useless_alloc(unittest.TestCase):
self._verify_stack_trace(f) self._verify_stack_trace(f)
def test_useless_alloc_1_on_broadcastable(self): def test_useless_alloc_with_shape_one(self):
alloc_lift = out2in(local_useless_alloc) alloc_lift = out2in(local_useless_alloc)
x = shared(self.rng.randn(2,)) x = shared(self.rng.randn(2,))
y = shared(self.rng.randn()) y = shared(self.rng.randn())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论