提交 ee07a356 authored 作者: Ziye Fan's avatar Ziye Fan

register local_mul_switch_sink in specialize optdb

上级 83bc6a46
......@@ -3293,6 +3293,7 @@ def local_remove_switch_const_cond(node):
return False
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.mul])
def local_mul_switch_sink(node):
......@@ -3322,6 +3323,7 @@ def local_mul_switch_sink(node):
return False
for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == T.switch:
# import ipdb;ipdb.set_trace()
switch = i.owner
try:
if (get_scalar_constant_value(
......
......@@ -4102,7 +4102,7 @@ class T_local_switch_sink(unittest.TestCase):
self.resm[idx][i])).sum() == self.resm[idx][i].size
else:
res = f(condition[1], x[1], -1)
# theano.printing.debugprint(f.maker.fgraph.outputs[0])
theano.printing.debugprint(f.maker.fgraph.outputs[0])
# import ipdb;ipdb.set_trace()
assert (res == numpy.asarray(self.
resm[idx])).sum() == self.resm[idx].size
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论