提交 0044349f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4554 from nouiz/scan_opt

Make recent scan opt change deterministic.
...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node):
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan(*nw_outer, **dict(return_list=True)) nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) return OrderedDict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else: else:
return False return False
...@@ -2072,8 +2072,8 @@ def scan_merge_inouts(node): ...@@ -2072,8 +2072,8 @@ def scan_merge_inouts(node):
new_outer_out_mit_mot.append(outer_omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
if remove: if remove:
return dict([("remove", remove)] + return OrderedDict([("remove", remove)] +
list(zip(node.outputs, na.outer_outputs))) list(zip(node.outputs, na.outer_outputs)))
return na.outer_outputs return na.outer_outputs
......
...@@ -152,6 +152,7 @@ from theano.tensor import basic as T ...@@ -152,6 +152,7 @@ from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import in2out, local_dimshuffle_lift from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import values_eq_approx_remove_inf_nan
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
...@@ -1465,6 +1466,7 @@ class GemmOptimizer(Optimizer): ...@@ -1465,6 +1466,7 @@ class GemmOptimizer(Optimizer):
if new_outputs: if new_outputs:
new_outputs, old_dot22 = new_outputs new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
new_outputs[0].tag.values_eq_approx = values_eq_approx_remove_inf_nan
try: try:
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(node.outputs, new_outputs)), list(zip(node.outputs, new_outputs)),
......
...@@ -132,9 +132,11 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -132,9 +132,11 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.a = tensor.tensor(dtype=dtype, broadcastable=()) self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self): def test_nan_beta_0(self):
mode = self.mode.including()
mode.check_isfinite = False
f = theano.function([self.A, self.x, self.y, self.a], f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x), self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode) mode=mode)
Aval = numpy.ones((3, 1), dtype=self.dtype) Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype) xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype) yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论