提交 44d43bcf authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Frederic

register opt local_sum_broadcastable only at specialize step

Otherwise this mess with a stability opt.
上级 f0715fba
...@@ -3144,7 +3144,12 @@ def local_cut_useless_reduce(node): ...@@ -3144,7 +3144,12 @@ def local_cut_useless_reduce(node):
return [summed] return [summed]
@register_canonicalize #Enabling this optimization at canonicalization step break this test:
#theano/tensor/tests/test_opt.py:T_local_sum.test_local_sum_broadcast_some_0
# see gh-790 issue.
#
#@register_canonicalize
@register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_broadcastable(node): def local_sum_broadcastable(node):
"""Remove reduction over broadcastable dimensions""" """Remove reduction over broadcastable dimensions"""
...@@ -3177,6 +3182,7 @@ def local_sum_broadcastable(node): ...@@ -3177,6 +3182,7 @@ def local_sum_broadcastable(node):
# -- in this case we can remove the reduction completely # -- in this case we can remove the reduction completely
return [new_reduced.astype(odtype)] return [new_reduced.astype(odtype)]
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_alloc(node): def local_sum_alloc(node):
......
...@@ -3188,7 +3188,8 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -3188,7 +3188,8 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
class T_local_sum(unittest.TestCase): class T_local_sum(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize') self.mode = theano.compile.get_default_mode().including('canonicalize',
'specialize')
def test_local_sum_all_to_none(self): def test_local_sum_all_to_none(self):
a = T.tensor3() a = T.tensor3()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论