Commit 334d3fdf authored by Brandon T. Willard, committed by Brandon T. Willard

Remove BadOptimization error when types don't match in FunctionGraph

Parent commit: 574174c4
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
from collections import OrderedDict
from io import StringIO
import aesara
from aesara.configdefaults import config
......@@ -475,40 +474,7 @@ class FunctionGraph(utils.MetaObject):
if verbose:
print(reason, var, new_var)
if var.type != new_var.type:
new_var_2 = var.type.convert_variable(new_var)
# We still make sure that the type converts correctly
if new_var_2 is None or new_var_2.type != var.type:
done = dict()
used_ids = dict()
old = aesara.compile.debugmode.debugprint(
var,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
new = aesara.compile.debugmode.debugprint(
new_var,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
raise toolbox.BadOptimization(
var,
new_var,
None,
None,
str(reason) + ". The type of the replacement must be the same.",
old,
new,
)
new_var = new_var_2
new_var = var.type.filter_variable(new_var, allow_convert=True)
if var not in self.variables:
# this variable isn't in the graph... don't raise an
......
......@@ -11,6 +11,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import (
Variable,
equal_computations,
graph_inputs,
io_toposort,
......@@ -105,13 +106,41 @@ class BadOptimization(Exception):
new_graph=None,
):
super().__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
self.new_r_val = new_r_val
self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph
done = dict()
used_ids = dict()
if isinstance(old_r, Variable):
self.old_graph = aesara.compile.debugmode.debugprint(
old_r,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
else:
self.old_graph = None
if isinstance(new_r, Variable):
self.new_graph = aesara.compile.debugmode.debugprint(
new_r,
prefix=" ",
depth=6,
file=StringIO(),
done=done,
print_type=True,
used_ids=used_ids,
).getvalue()
else:
self.new_graph = None
# To allow extending the error message of an existing error.
self.full_err = None
......
......@@ -5,7 +5,14 @@ import pytest
import aesara
import aesara.tensor as aet
from aesara.compile import debugmode
from aesara.compile.debugmode import (
BadDestroyMap,
BadThunkOutput,
BadViewMap,
DebugMode,
InvalidValueError,
StochasticOrder,
)
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op
......@@ -19,7 +26,7 @@ from tests import unittest_tools as utt
def test_debugmode_basic():
    """Smoke test: a trivial graph compiles and evaluates under ``DebugMode``.

    The stripped diff left both the pre-change (``debugmode.DebugMode()``)
    and post-change (``DebugMode()``) call sites in the text; only the
    updated, directly-imported form is kept.
    """
    x = dvector()
    f = aesara.function([x], ((2.0 * x) + 7) / 2.0, mode=DebugMode())
    # Should evaluate without raising.
    f([1, 2])
......@@ -210,18 +217,18 @@ def test_badthunkoutput():
f_good = aesara.function(
[a, b],
off_by_half(a, b),
mode=debugmode.DebugMode(check_c_code=config.cxx),
mode=DebugMode(check_c_code=config.cxx),
)
f_inconsistent = aesara.function(
[a, b],
inconsistent(a, b),
mode=debugmode.DebugMode(check_c_code=config.cxx),
mode=DebugMode(check_c_code=config.cxx),
)
# this should evaluate with no error
f_good([1.0, 2.0, 3.0], [2, 3, 4])
with pytest.raises(debugmode.BadThunkOutput) as einfo:
with pytest.raises(BadThunkOutput) as einfo:
f_inconsistent([1.0, 2.0, 3.0], [2, 3, 4])
assert einfo.value.r.owner.op is inconsistent
......@@ -241,9 +248,9 @@ def test_badoptimization():
a = dvector()
b = dvector()
f = aesara.function([a, b], a + b, mode=debugmode.DebugMode(optimizer=opt))
f = aesara.function([a, b], a + b, mode=DebugMode(optimizer=opt))
with pytest.raises(debugmode.BadOptimization) as einfo:
with pytest.raises(BadOptimization) as einfo:
f(
[1.0, 2.0, 3.0],
[2, 3, 4],
......@@ -282,7 +289,7 @@ def test_badoptimization_opt_err():
a = dvector()
b = dvector()
f = aesara.function([a, b], a + b, mode=debugmode.DebugMode(optimizer=opt))
f = aesara.function([a, b], a + b, mode=DebugMode(optimizer=opt))
with pytest.raises(ValueError, match=r"insert_bigger_b_add"):
f(
[1.0, 2.0, 3.0],
......@@ -290,12 +297,12 @@ def test_badoptimization_opt_err():
)
# Test that opt that do an illegal change still get the error from graph.
with pytest.raises(BadOptimization, match=r"insert_bad_dtype") as einfo:
with pytest.raises(TypeError) as einfo:
with config.change_flags(on_opt_error="raise"):
f2 = aesara.function(
[a, b],
a + b,
mode=debugmode.DebugMode(optimizer=opt2, stability_patience=1),
mode=DebugMode(optimizer=opt2, stability_patience=1),
)
f2(
[1.0, 2.0, 3.0],
......@@ -303,7 +310,7 @@ def test_badoptimization_opt_err():
)
# Test that we can reraise the error with an extended message
with pytest.raises(BadOptimization):
with pytest.raises(TypeError):
e = einfo.value
new_e = e.__class__("TTT" + str(e))
exc_type, exc_value, exc_trace = sys.exc_info()
......@@ -332,11 +339,11 @@ def test_stochasticoptimization():
a = dvector()
b = dvector()
with pytest.raises(debugmode.StochasticOrder):
with pytest.raises(StochasticOrder):
aesara.function(
[a, b],
add(a, b),
mode=debugmode.DebugMode(
mode=DebugMode(
optimizer=opt,
check_c_code=True,
stability_patience=max(2, config.DebugMode__patience),
......@@ -349,7 +356,7 @@ def test_stochasticoptimization():
)
def test_just_c_code():
    """With ``check_py_code=False`` the op's C implementation alone is run.

    Duplicate pre/post-change ``aesara.function`` lines from the stripped
    diff are collapsed to the updated one (bare ``DebugMode`` import).
    """
    x = dvector()
    f = aesara.function([x], wb2(x), mode=DebugMode(check_py_code=False))
    # wb2 is expected to double its input elementwise — TODO confirm
    # against the op's definition earlier in this test file.
    assert np.all(f([1, 2]) == [2, 4])
......@@ -369,7 +376,7 @@ def test_baddestroymap():
y = dvector()
f = aesara.function([x, y], BadAdd()(x, y), mode="DEBUG_MODE")
with pytest.raises(debugmode.BadDestroyMap):
with pytest.raises(BadDestroyMap):
f([1, 2], [3, 4])
......@@ -378,8 +385,8 @@ def test_baddestroymap():
)
def test_baddestroymap_c():
    """A C op that mutates an input without declaring it in its destroy map
    must be caught by DebugMode as ``BadDestroyMap``.

    The stripped diff left both the old (``debugmode.``-qualified) and new
    (directly imported) versions of two statements; only the updated
    versions are kept.
    """
    x = dvector()
    f = aesara.function([x], wb2i(x), mode=DebugMode(check_py_code=False))
    with pytest.raises(BadDestroyMap):
        # The call itself should raise before the assertion matters.
        assert np.all(f([1, 2]) == [2, 4])
......@@ -408,14 +415,14 @@ class TestViewMap:
x = dvector()
y = dvector()
f = aesara.function([x, y], self.BadAddRef()(x, y), mode="DEBUG_MODE")
with pytest.raises(debugmode.BadViewMap):
with pytest.raises(BadViewMap):
f([1, 2], [3, 4])
def test_badviewmap_slice(self):
    """An op returning a slice (view) of an input without declaring it in
    its view map must raise ``BadViewMap`` under DEBUG_MODE.

    The duplicated pre-change ``debugmode.BadViewMap`` raises-line from the
    stripped diff is dropped in favor of the directly imported name.
    """
    x = dvector()
    y = dvector()
    f = aesara.function([x, y], self.BadAddSlice()(x, y), mode="DEBUG_MODE")
    with pytest.raises(BadViewMap):
        f([1, 2], [3, 4])
def test_goodviewmap(self):
......@@ -432,8 +439,8 @@ class TestViewMap:
)
def test_badviewmap_c(self):
    """A C op returning an undeclared view of its input must be caught as
    ``BadViewMap`` when only the C code path is checked.

    Duplicate old/new statement pairs from the stripped diff are collapsed
    to the updated (directly imported) forms.
    """
    x = dvector()
    f = aesara.function([x], wb1i(x), mode=DebugMode(check_py_code=False))
    with pytest.raises(BadViewMap):
        f([1, 2])
def test_aliased_outputs_ok(self):
......@@ -537,7 +544,7 @@ class TestViewMap:
out = bad_xy0 * 2 + bad_xy1 * 2
f = aesara.function([x, y], out, mode="DEBUG_MODE")
with pytest.raises(debugmode.BadViewMap):
with pytest.raises(BadViewMap):
f([1, 2, 3, 4], [5, 6, 7, 8])
# the situation can be rescued by picking one of the inputs and
......@@ -569,16 +576,16 @@ class TestCheckIsfinite:
# ValueError
# if not, DebugMode will check internally, and raise InvalidValueError
# passing an invalid value as an input should trigger ValueError
with pytest.raises(debugmode.InvalidValueError):
with pytest.raises(InvalidValueError):
f(np.log([3, -4, 5]).astype(config.floatX))
with pytest.raises(debugmode.InvalidValueError):
with pytest.raises(InvalidValueError):
f((np.asarray([0, 1.0, 0]) / 0).astype(config.floatX))
with pytest.raises(debugmode.InvalidValueError):
with pytest.raises(InvalidValueError):
f((np.asarray([1.0, 1.0, 1.0]) / 0).astype(config.floatX))
# generating an invalid value internally should trigger
# InvalidValueError
with pytest.raises(debugmode.InvalidValueError):
with pytest.raises(InvalidValueError):
g(np.asarray([3, -4, 5], dtype=config.floatX))
# this should disable the exception
......@@ -589,9 +596,7 @@ class TestCheckIsfinite:
def test_check_isfinite_disabled(self):
    """``check_isfinite=False`` lets NaN/inf values flow through DebugMode.

    The stripped diff left both the old multi-line ``aesara.function`` call
    and the new single-line one; only the updated version is kept.
    """
    x = dvector()
    f = aesara.function([x], (x + 2) * 5, mode=DebugMode(check_isfinite=False))
    # np.log of a negative entry produces NaN; it must NOT raise here.
    f(np.log([3, -4, 5]))
......@@ -743,7 +748,7 @@ class TestPreallocatedOutput:
b_val = self.rng.randn(7, 7).astype("float32")
# Should work
mode = debugmode.DebugMode(check_preallocated_output=["c_contiguous"])
mode = DebugMode(check_preallocated_output=["c_contiguous"])
f = aesara.function([a, b], out, mode=mode)
f(a_val, b_val)
......@@ -752,12 +757,12 @@ class TestPreallocatedOutput:
# Should raise an Exception, since the output buffer is
# used incorrectly.
mode = debugmode.DebugMode(check_preallocated_output=["f_contiguous"])
mode = DebugMode(check_preallocated_output=["f_contiguous"])
f = aesara.function([a, b], out, mode=mode)
if config.cxx:
with pytest.raises(debugmode.BadThunkOutput):
with pytest.raises(BadThunkOutput):
f(a_val, b_val)
else:
# The python code of this op is good.
......@@ -774,7 +779,7 @@ class TestPreallocatedOutput:
b_val = self.rng.randn(7, 7).astype("float32")
# Should work
mode = debugmode.DebugMode(check_preallocated_output=["c_contiguous"])
mode = DebugMode(check_preallocated_output=["c_contiguous"])
f = aesara.function([a, b], out, mode=mode)
f(a_val, b_val)
......@@ -783,12 +788,12 @@ class TestPreallocatedOutput:
# Should raise an Exception, since the output buffer is
# used incorrectly.
mode = debugmode.DebugMode(check_preallocated_output=["f_contiguous"])
mode = DebugMode(check_preallocated_output=["f_contiguous"])
f = aesara.function([a, b], out, mode=mode)
if config.cxx:
with pytest.raises(debugmode.BadThunkOutput):
with pytest.raises(BadThunkOutput):
f(a_val, b_val)
else:
# The python code of this op is good.
......
......@@ -5,7 +5,6 @@ import pytest
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph, MissingInputError
from aesara.graph.toolbox import BadOptimization
from tests.graph.utils import MyVariable, MyVariable2, op1, op2, op3
......@@ -216,7 +215,7 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(BadOptimization):
with pytest.raises(TypeError):
var0 = MyVariable2("var0")
# The types don't match and one cannot be converted to the other
fg.replace(var3, var0)
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Finish editing this message first!
Please register or sign in to comment