Commit 8cbd9840, authored by Brandon T. Willard, committed by Brandon T. Willard

Add basic optimization canonicalizations to "fast_compile" mode by default

Parent commit: 842d3bcd
......@@ -507,7 +507,9 @@ def register_canonicalize(lopt, *tags, **kwargs):
return register
else:
name = kwargs.pop("name", None) or lopt.__name__
compile.optdb["canonicalize"].register(name, lopt, "fast_run", *tags, **kwargs)
compile.optdb["canonicalize"].register(
name, lopt, "fast_run", "fast_compile", *tags, **kwargs
)
return lopt
......
......@@ -48,7 +48,7 @@ class TestPyDotFormatter:
assert len(sub_graphs) == 2
ofg1, ofg2 = sub_graphs
if config.mode == "FAST_COMPILE":
assert len(ofg1.get_nodes()) == 9
assert len(ofg1.get_nodes()) == 8
else:
assert len(ofg1.get_nodes()) == 5
assert len(ofg1.get_nodes()) == len(ofg2.get_nodes())
......
......@@ -1563,18 +1563,8 @@ class TestDots(utt.InferShapeTester):
assert np.all(f_a(vx, vy) == f_b(vx, vy))
topo = f_a.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE":
nb = 0
else:
nb = 1
assert (
sum(
[
isinstance(node.op, (Dot, Usmm, UsmmCscDense))
for node in topo
]
)
== nb
assert not any(
isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
)
def test_int32_dtype(self):
......@@ -1822,13 +1812,8 @@ class TestUsmm:
)
assert all(f_shape(a_data, x_data, y_data) == f_b_out.shape)
topo = f_shape.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE":
nb = 0
else:
nb = 1
assert (
sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo])
== nb
assert not any(
isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
)
......
......@@ -1028,11 +1028,6 @@ class TestJoinAndSplit:
self.split_op_class = Split
self.make_vector_op = MakeVector()
self.floatX = config.floatX
self.hide_error = config.mode not in [
"DebugMode",
"DEBUG_MODE",
"FAST_COMPILE",
]
self.shared = shared
def eval_outputs_and_check_join(self, outputs):
......@@ -1712,17 +1707,6 @@ class TestJoinAndSplit:
for node in topo:
assert not isinstance(node.op, type(self.join_op))
with config.change_flags(compute_test_value="off"):
# Test hide error
x1.set_value(get_mat(3, 4))
x2.set_value(get_mat(3, 4))
x3.set_value(get_mat(2, 5))
if not self.hide_error:
with pytest.raises(ValueError):
f()
else:
f()
def test_rebroadcast(self):
# Regression test for a crash that used to happen when rebroadcasting.
x = TensorType(self.floatX, [False, False, True])()
......
......@@ -439,14 +439,6 @@ def makeSharedTester(
with pytest.raises(AssertionError):
specify_shape_fct()
# No assertion will be raised as the Op is removed from the graph
# when their is optimization
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
shape_constant_fct()
else:
with pytest.raises(AssertionError):
shape_constant_fct()
def test_specify_shape_partial(self):
dtype = self.dtype
if dtype is None:
......@@ -502,13 +494,6 @@ def makeSharedTester(
with pytest.raises(AssertionError):
specify_shape_fct()
# No assertion will be raised as the Op is removed from the graph
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
shape_constant_fct()
else:
with pytest.raises(AssertionError):
shape_constant_fct()
def test_specify_shape_inplace(self):
# test that specify_shape don't break inserting inplace op
......
......@@ -1077,7 +1077,8 @@ class TestSubtensor(utt.OptimizationTestMixin):
else:
ops = subtensor_ops
if idx is idxs[0]:
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=2)
# TODO FIXME: This is a very poorly specified test.
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0)
f()
def test_wrong_exception_regression(self):
......@@ -1129,7 +1130,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
data = np.asarray(data, dtype=self.dtype)
n = self.shared(data)
t = n[idx]
f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=1)
f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=0)
val = f()
assert np.allclose(val, data[idx].shape)
......
......@@ -1174,10 +1174,6 @@ class TestLocalSubtensorMerge:
data = self.rng.uniform(size=(8, 8, 8)).astype(config.floatX)
x = tensor3("x")
nops = 1
if config.mode == "FAST_COMPILE":
nops = 2
# test 1)
y = x[3:6, 2:6, 1:7][1]
fun = function([x], y)
......@@ -1185,7 +1181,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[3:6, 2:6, 1:7][1])
assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops
== 1
)
# test 2)
......@@ -1195,7 +1191,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[2, 3][1])
assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops
== 1
)
# test 3)
......@@ -1205,7 +1201,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[3:6, 2, 1:7][1])
assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops
== 1
)
def test_scalar6(self):
......@@ -1590,10 +1586,7 @@ class TestAllocZero:
assert len(inc_nodes) == 1
node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc
mode = config.mode
assert (mode != "FAST_COMPILE" and node_is_set_instead_of_inc) or (
mode == "FAST_COMPILE" and not node_is_set_instead_of_inc
)
assert node_is_set_instead_of_inc
test_X = np.random.random((4, 4)).astype(config.floatX)
utt.assert_allclose(f(test_X), test_X)
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment