Commit 334d3fdf, authored by Brandon T. Willard, committed by Brandon T. Willard

Remove the BadOptimization error raised when types don't match in FunctionGraph

Parent commit: 574174c4
"""A container for specifying and manipulating a graph with distinct inputs and outputs.""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time import time
from collections import OrderedDict from collections import OrderedDict
from io import StringIO
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -475,40 +474,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -475,40 +474,7 @@ class FunctionGraph(utils.MetaObject):
if verbose: if verbose:
print(reason, var, new_var) print(reason, var, new_var)
if var.type != new_var.type: new_var = var.type.filter_variable(new_var, allow_convert=True)
new_var_2 = var.type.convert_variable(new_var)
# We still make sure that the type converts correctly
if new_var_2 is None or new_var_2.type != var.type:
done = dict()
used_ids = dict()
old = aesara.compile.debugmode.debugprint(
var,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
new = aesara.compile.debugmode.debugprint(
new_var,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
raise toolbox.BadOptimization(
var,
new_var,
None,
None,
str(reason) + ". The type of the replacement must be the same.",
old,
new,
)
new_var = new_var_2
if var not in self.variables: if var not in self.variables:
# this variable isn't in the graph... don't raise an # this variable isn't in the graph... don't raise an
......
...@@ -11,6 +11,7 @@ import numpy as np ...@@ -11,6 +11,7 @@ import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Variable,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_toposort, io_toposort,
...@@ -105,13 +106,41 @@ class BadOptimization(Exception): ...@@ -105,13 +106,41 @@ class BadOptimization(Exception):
new_graph=None, new_graph=None,
): ):
super().__init__() super().__init__()
self.old_r = old_r self.old_r = old_r
self.new_r = new_r self.new_r = new_r
self.old_r_val = old_r_val self.old_r_val = old_r_val
self.new_r_val = new_r_val self.new_r_val = new_r_val
self.reason = reason self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph done = dict()
used_ids = dict()
if isinstance(old_r, Variable):
self.old_graph = aesara.compile.debugmode.debugprint(
old_r,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
else:
self.old_graph = None
if isinstance(new_r, Variable):
self.new_graph = aesara.compile.debugmode.debugprint(
new_r,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
else:
self.new_graph = None
# To allow extending the error message of an existing error. # To allow extending the error message of an existing error.
self.full_err = None self.full_err = None
......
...@@ -5,7 +5,14 @@ import pytest ...@@ -5,7 +5,14 @@ import pytest
import aesara import aesara
import aesara.tensor as aet import aesara.tensor as aet
from aesara.compile import debugmode from aesara.compile.debugmode import (
BadDestroyMap,
BadThunkOutput,
BadViewMap,
DebugMode,
InvalidValueError,
StochasticOrder,
)
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
...@@ -19,7 +26,7 @@ from tests import unittest_tools as utt ...@@ -19,7 +26,7 @@ from tests import unittest_tools as utt
def test_debugmode_basic():
    """Smoke test: a simple scalar graph compiles and evaluates under ``DebugMode``."""
    x = dvector()
    # (2x + 7) / 2 exercises a few elemwise ops; DebugMode cross-checks
    # the Python and C implementations while evaluating.
    f = aesara.function([x], ((2.0 * x) + 7) / 2.0, mode=DebugMode())
    f([1, 2])
...@@ -210,18 +217,18 @@ def test_badthunkoutput(): ...@@ -210,18 +217,18 @@ def test_badthunkoutput():
f_good = aesara.function( f_good = aesara.function(
[a, b], [a, b],
off_by_half(a, b), off_by_half(a, b),
mode=debugmode.DebugMode(check_c_code=config.cxx), mode=DebugMode(check_c_code=config.cxx),
) )
f_inconsistent = aesara.function( f_inconsistent = aesara.function(
[a, b], [a, b],
inconsistent(a, b), inconsistent(a, b),
mode=debugmode.DebugMode(check_c_code=config.cxx), mode=DebugMode(check_c_code=config.cxx),
) )
# this should evaluate with no error # this should evaluate with no error
f_good([1.0, 2.0, 3.0], [2, 3, 4]) f_good([1.0, 2.0, 3.0], [2, 3, 4])
with pytest.raises(debugmode.BadThunkOutput) as einfo: with pytest.raises(BadThunkOutput) as einfo:
f_inconsistent([1.0, 2.0, 3.0], [2, 3, 4]) f_inconsistent([1.0, 2.0, 3.0], [2, 3, 4])
assert einfo.value.r.owner.op is inconsistent assert einfo.value.r.owner.op is inconsistent
...@@ -241,9 +248,9 @@ def test_badoptimization(): ...@@ -241,9 +248,9 @@ def test_badoptimization():
a = dvector() a = dvector()
b = dvector() b = dvector()
f = aesara.function([a, b], a + b, mode=debugmode.DebugMode(optimizer=opt)) f = aesara.function([a, b], a + b, mode=DebugMode(optimizer=opt))
with pytest.raises(debugmode.BadOptimization) as einfo: with pytest.raises(BadOptimization) as einfo:
f( f(
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
[2, 3, 4], [2, 3, 4],
...@@ -282,7 +289,7 @@ def test_badoptimization_opt_err(): ...@@ -282,7 +289,7 @@ def test_badoptimization_opt_err():
a = dvector() a = dvector()
b = dvector() b = dvector()
f = aesara.function([a, b], a + b, mode=debugmode.DebugMode(optimizer=opt)) f = aesara.function([a, b], a + b, mode=DebugMode(optimizer=opt))
with pytest.raises(ValueError, match=r"insert_bigger_b_add"): with pytest.raises(ValueError, match=r"insert_bigger_b_add"):
f( f(
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
...@@ -290,12 +297,12 @@ def test_badoptimization_opt_err(): ...@@ -290,12 +297,12 @@ def test_badoptimization_opt_err():
) )
# Test that opt that do an illegal change still get the error from graph. # Test that opt that do an illegal change still get the error from graph.
with pytest.raises(BadOptimization, match=r"insert_bad_dtype") as einfo: with pytest.raises(TypeError) as einfo:
with config.change_flags(on_opt_error="raise"): with config.change_flags(on_opt_error="raise"):
f2 = aesara.function( f2 = aesara.function(
[a, b], [a, b],
a + b, a + b,
mode=debugmode.DebugMode(optimizer=opt2, stability_patience=1), mode=DebugMode(optimizer=opt2, stability_patience=1),
) )
f2( f2(
[1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
...@@ -303,7 +310,7 @@ def test_badoptimization_opt_err(): ...@@ -303,7 +310,7 @@ def test_badoptimization_opt_err():
) )
# Test that we can reraise the error with an extended message # Test that we can reraise the error with an extended message
with pytest.raises(BadOptimization): with pytest.raises(TypeError):
e = einfo.value e = einfo.value
new_e = e.__class__("TTT" + str(e)) new_e = e.__class__("TTT" + str(e))
exc_type, exc_value, exc_trace = sys.exc_info() exc_type, exc_value, exc_trace = sys.exc_info()
...@@ -332,11 +339,11 @@ def test_stochasticoptimization(): ...@@ -332,11 +339,11 @@ def test_stochasticoptimization():
a = dvector() a = dvector()
b = dvector() b = dvector()
with pytest.raises(debugmode.StochasticOrder): with pytest.raises(StochasticOrder):
aesara.function( aesara.function(
[a, b], [a, b],
add(a, b), add(a, b),
mode=debugmode.DebugMode( mode=DebugMode(
optimizer=opt, optimizer=opt,
check_c_code=True, check_c_code=True,
stability_patience=max(2, config.DebugMode__patience), stability_patience=max(2, config.DebugMode__patience),
...@@ -349,7 +356,7 @@ def test_stochasticoptimization(): ...@@ -349,7 +356,7 @@ def test_stochasticoptimization():
) )
def test_just_c_code():
    """With ``check_py_code=False``, DebugMode runs only the C implementation.

    ``wb2`` doubles its input; we verify the C-only path still computes 2*x.
    """
    x = dvector()
    f = aesara.function([x], wb2(x), mode=DebugMode(check_py_code=False))
    assert np.all(f([1, 2]) == [2, 4])
...@@ -369,7 +376,7 @@ def test_baddestroymap(): ...@@ -369,7 +376,7 @@ def test_baddestroymap():
y = dvector() y = dvector()
f = aesara.function([x, y], BadAdd()(x, y), mode="DEBUG_MODE") f = aesara.function([x, y], BadAdd()(x, y), mode="DEBUG_MODE")
with pytest.raises(debugmode.BadDestroyMap): with pytest.raises(BadDestroyMap):
f([1, 2], [3, 4]) f([1, 2], [3, 4])
...@@ -378,8 +385,8 @@ def test_baddestroymap(): ...@@ -378,8 +385,8 @@ def test_baddestroymap():
) )
def test_baddestroymap_c():
    """DebugMode must raise ``BadDestroyMap`` for C code that destroys an input.

    ``wb2i`` mutates its input in C without declaring it in ``destroy_map``;
    with Python checks disabled, DebugMode has to catch this from the C run.
    """
    x = dvector()
    f = aesara.function([x], wb2i(x), mode=DebugMode(check_py_code=False))
    with pytest.raises(BadDestroyMap):
        # The numeric result would be correct, but the undeclared
        # destruction must be reported before it can be asserted.
        assert np.all(f([1, 2]) == [2, 4])
...@@ -408,14 +415,14 @@ class TestViewMap: ...@@ -408,14 +415,14 @@ class TestViewMap:
x = dvector() x = dvector()
y = dvector() y = dvector()
f = aesara.function([x, y], self.BadAddRef()(x, y), mode="DEBUG_MODE") f = aesara.function([x, y], self.BadAddRef()(x, y), mode="DEBUG_MODE")
with pytest.raises(debugmode.BadViewMap): with pytest.raises(BadViewMap):
f([1, 2], [3, 4]) f([1, 2], [3, 4])
def test_badviewmap_slice(self):
    """An op returning a slice view of an input without declaring it in
    ``view_map`` must trigger ``BadViewMap`` under DebugMode."""
    x = dvector()
    y = dvector()
    f = aesara.function([x, y], self.BadAddSlice()(x, y), mode="DEBUG_MODE")
    with pytest.raises(BadViewMap):
        f([1, 2], [3, 4])
def test_goodviewmap(self): def test_goodviewmap(self):
...@@ -432,8 +439,8 @@ class TestViewMap: ...@@ -432,8 +439,8 @@ class TestViewMap:
) )
def test_badviewmap_c(self):
    """C code aliasing an input as output without a ``view_map`` entry
    must raise ``BadViewMap`` (Python checks disabled to force the C path)."""
    x = dvector()
    f = aesara.function([x], wb1i(x), mode=DebugMode(check_py_code=False))
    with pytest.raises(BadViewMap):
        f([1, 2])
def test_aliased_outputs_ok(self): def test_aliased_outputs_ok(self):
...@@ -537,7 +544,7 @@ class TestViewMap: ...@@ -537,7 +544,7 @@ class TestViewMap:
out = bad_xy0 * 2 + bad_xy1 * 2 out = bad_xy0 * 2 + bad_xy1 * 2
f = aesara.function([x, y], out, mode="DEBUG_MODE") f = aesara.function([x, y], out, mode="DEBUG_MODE")
with pytest.raises(debugmode.BadViewMap): with pytest.raises(BadViewMap):
f([1, 2, 3, 4], [5, 6, 7, 8]) f([1, 2, 3, 4], [5, 6, 7, 8])
# the situation can be rescued by picking one of the inputs and # the situation can be rescued by picking one of the inputs and
...@@ -569,16 +576,16 @@ class TestCheckIsfinite: ...@@ -569,16 +576,16 @@ class TestCheckIsfinite:
# ValueError # ValueError
# if not, DebugMode will check internally, and raise InvalidValueError # if not, DebugMode will check internally, and raise InvalidValueError
# passing an invalid value as an input should trigger ValueError # passing an invalid value as an input should trigger ValueError
with pytest.raises(debugmode.InvalidValueError): with pytest.raises(InvalidValueError):
f(np.log([3, -4, 5]).astype(config.floatX)) f(np.log([3, -4, 5]).astype(config.floatX))
with pytest.raises(debugmode.InvalidValueError): with pytest.raises(InvalidValueError):
f((np.asarray([0, 1.0, 0]) / 0).astype(config.floatX)) f((np.asarray([0, 1.0, 0]) / 0).astype(config.floatX))
with pytest.raises(debugmode.InvalidValueError): with pytest.raises(InvalidValueError):
f((np.asarray([1.0, 1.0, 1.0]) / 0).astype(config.floatX)) f((np.asarray([1.0, 1.0, 1.0]) / 0).astype(config.floatX))
# generating an invalid value internally should trigger # generating an invalid value internally should trigger
# InvalidValueError # InvalidValueError
with pytest.raises(debugmode.InvalidValueError): with pytest.raises(InvalidValueError):
g(np.asarray([3, -4, 5], dtype=config.floatX)) g(np.asarray([3, -4, 5], dtype=config.floatX))
# this should disable the exception # this should disable the exception
...@@ -589,9 +596,7 @@ class TestCheckIsfinite: ...@@ -589,9 +596,7 @@ class TestCheckIsfinite:
def test_check_isfinite_disabled(self): def test_check_isfinite_disabled(self):
x = dvector() x = dvector()
f = aesara.function( f = aesara.function([x], (x + 2) * 5, mode=DebugMode(check_isfinite=False))
[x], (x + 2) * 5, mode=debugmode.DebugMode(check_isfinite=False)
)
# nan should go through # nan should go through
f(np.log([3, -4, 5])) f(np.log([3, -4, 5]))
...@@ -743,7 +748,7 @@ class TestPreallocatedOutput: ...@@ -743,7 +748,7 @@ class TestPreallocatedOutput:
b_val = self.rng.randn(7, 7).astype("float32") b_val = self.rng.randn(7, 7).astype("float32")
# Should work # Should work
mode = debugmode.DebugMode(check_preallocated_output=["c_contiguous"]) mode = DebugMode(check_preallocated_output=["c_contiguous"])
f = aesara.function([a, b], out, mode=mode) f = aesara.function([a, b], out, mode=mode)
f(a_val, b_val) f(a_val, b_val)
...@@ -752,12 +757,12 @@ class TestPreallocatedOutput: ...@@ -752,12 +757,12 @@ class TestPreallocatedOutput:
# Should raise an Exception, since the output buffer is # Should raise an Exception, since the output buffer is
# used incorrectly. # used incorrectly.
mode = debugmode.DebugMode(check_preallocated_output=["f_contiguous"]) mode = DebugMode(check_preallocated_output=["f_contiguous"])
f = aesara.function([a, b], out, mode=mode) f = aesara.function([a, b], out, mode=mode)
if config.cxx: if config.cxx:
with pytest.raises(debugmode.BadThunkOutput): with pytest.raises(BadThunkOutput):
f(a_val, b_val) f(a_val, b_val)
else: else:
# The python code of this op is good. # The python code of this op is good.
...@@ -774,7 +779,7 @@ class TestPreallocatedOutput: ...@@ -774,7 +779,7 @@ class TestPreallocatedOutput:
b_val = self.rng.randn(7, 7).astype("float32") b_val = self.rng.randn(7, 7).astype("float32")
# Should work # Should work
mode = debugmode.DebugMode(check_preallocated_output=["c_contiguous"]) mode = DebugMode(check_preallocated_output=["c_contiguous"])
f = aesara.function([a, b], out, mode=mode) f = aesara.function([a, b], out, mode=mode)
f(a_val, b_val) f(a_val, b_val)
...@@ -783,12 +788,12 @@ class TestPreallocatedOutput: ...@@ -783,12 +788,12 @@ class TestPreallocatedOutput:
# Should raise an Exception, since the output buffer is # Should raise an Exception, since the output buffer is
# used incorrectly. # used incorrectly.
mode = debugmode.DebugMode(check_preallocated_output=["f_contiguous"]) mode = DebugMode(check_preallocated_output=["f_contiguous"])
f = aesara.function([a, b], out, mode=mode) f = aesara.function([a, b], out, mode=mode)
if config.cxx: if config.cxx:
with pytest.raises(debugmode.BadThunkOutput): with pytest.raises(BadThunkOutput):
f(a_val, b_val) f(a_val, b_val)
else: else:
# The python code of this op is good. # The python code of this op is good.
......
...@@ -5,7 +5,6 @@ import pytest ...@@ -5,7 +5,6 @@ import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph, MissingInputError from aesara.graph.fg import FunctionGraph, MissingInputError
from aesara.graph.toolbox import BadOptimization
from tests.graph.utils import MyVariable, MyVariable2, op1, op2, op3 from tests.graph.utils import MyVariable, MyVariable2, op1, op2, op3
...@@ -216,7 +215,7 @@ class TestFunctionGraph: ...@@ -216,7 +215,7 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(BadOptimization): with pytest.raises(TypeError):
var0 = MyVariable2("var0") var0 = MyVariable2("var0")
# The types don't match and one cannot be converted to the other # The types don't match and one cannot be converted to the other
fg.replace(var3, var0) fg.replace(var3, var0)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment