提交 ba5f2a3a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2773 from nouiz/fix_nan

[REGRESSION] Fix nan
...@@ -2978,8 +2978,8 @@ def local_mul_switch_sink(node): ...@@ -2978,8 +2978,8 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
try: try:
if (isinstance(switch.inputs[0], Constant) and if (get_scalar_constant_value(
get_scalar_constant_value(switch.inputs[1]) == 0.): switch.inputs[1], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))] T.mul(*(listmul + [switch.inputs[2]])))]
...@@ -2988,8 +2988,8 @@ def local_mul_switch_sink(node): ...@@ -2988,8 +2988,8 @@ def local_mul_switch_sink(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if (isinstance(switch.inputs[2], Constant) and if (get_scalar_constant_value(
get_scalar_constant_value(switch.inputs[2]) == 0.): switch.inputs[2], only_process_constants=True) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)] T.mul(*(listmul + [switch.inputs[1]])), 0)]
...@@ -3676,6 +3676,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3676,6 +3676,9 @@ class Canonizer(gof.LocalOptimizer):
new = _fill_chain(new, node.inputs)[0] new = _fill_chain(new, node.inputs)[0]
if new.type == out.type: if new.type == out.type:
# This happen with test
# theano/tensor/tests/test_opt.py:T_local_switch_sink
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
return [new] return [new]
else: else:
_logger.warning(' '.join(('CANONIZE FAILED: new, out = ', _logger.warning(' '.join(('CANONIZE FAILED: new, out = ',
......
...@@ -3962,23 +3962,33 @@ class T_local_switch_sink(unittest.TestCase): ...@@ -3962,23 +3962,33 @@ class T_local_switch_sink(unittest.TestCase):
def test_local_mul_switch_sink(self): def test_local_mul_switch_sink(self):
c = T.dscalar() c = T.dscalar()
idx = 0 idx = 0
for condition in [(T.dmatrix('cond'), self.condm), (T.dvector('cond'), self.condv), (T.dscalar('cond'), self.conds)]: for condition in [(T.dmatrix('cond'), self.condm),
for x in [(T.dmatrix('x'), self.xm), (T.dvector('x'), self.xv), (T.dscalar('x'), self.xs)]: (T.dvector('cond'), self.condv),
y = T.mul(T.switch(condition[0] > 0, 1. * x[0], (T.dscalar('cond'), self.conds)]:
0. * x[0]), T.switch(condition[0] > 0, 1.*x[0], T.log(c)*x[0])) for x in [(T.dmatrix('x'), self.xm), (T.dvector('x'), self.xv),
f = theano.function([condition[0], x[0], c] (T.dscalar('x'), self.xs)]:
, [y], mode=self.mode) y = T.mul(T.switch(condition[0] > 0, 1. * x[0], 0. * x[0]),
T.switch(condition[0] > 0,
1. * x[0], T.log(c) * x[0]))
f = theano.function([condition[0], x[0], c],
[y], mode=self.mode)
if type(condition[1]) is list: if type(condition[1]) is list:
for i in range(len(condition[1])): for i in range(len(condition[1])):
res = f(condition[1][i], x[1], -1) res = f(condition[1][i], x[1], -1)
assert (res == numpy. assert (res == numpy.asarray(
asarray(self.resm[idx][i])).sum() == self.resm[idx][i].size self.resm[idx][i])).sum() == self.resm[idx][i].size
else: else:
res = f(condition[1], x[1], -1) res = f(condition[1], x[1], -1)
assert (res == numpy.asarray(self. assert (res == numpy.asarray(self.
resm[idx])).sum() == self.resm[idx].size resm[idx])).sum() == self.resm[idx].size
idx += 1 idx += 1
# This case caused a missed optimization in the past.
x = T.dscalar('x')
y = T.switch(x < 7, x, T.sqrt(x - 7))
f = theano.function([x], T.grad(y, x), self.mode)
assert f(5) == 1, f(5)
@attr('slow') @attr('slow')
def test_local_div_switch_sink(self): def test_local_div_switch_sink(self):
c = T.dscalar() c = T.dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论