提交 c9f28f19 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix stability opt not being applied

上级 c5201091
......@@ -2978,8 +2978,8 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch:
switch = i.owner
try:
if (isinstance(switch.inputs[0], Constant) and
get_scalar_constant_value(switch.inputs[1]) == 0.):
if (get_scalar_constant_value(
switch.inputs[1], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
......@@ -2988,8 +2988,8 @@ def local_mul_switch_sink(node):
except NotScalarConstantError:
pass
try:
if (isinstance(switch.inputs[2], Constant) and
get_scalar_constant_value(switch.inputs[2]) == 0.):
if (get_scalar_constant_value(
switch.inputs[2], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
......
......@@ -3983,6 +3983,12 @@ class T_local_switch_sink(unittest.TestCase):
resm[idx])).sum() == self.resm[idx].size
idx += 1
# This case caused a missed optimization in the past.
x = T.dscalar('x')
y = T.switch(x < 7, x, T.sqrt(x - 7))
f = theano.function([x], T.grad(y, x))
assert f(5) == 1
@attr('slow')
def test_local_div_switch_sink(self):
c = T.dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论