提交 8cbd9840 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add basic optimization canonicalizations to "fast_compile" mode by default

上级 842d3bcd
...@@ -507,7 +507,9 @@ def register_canonicalize(lopt, *tags, **kwargs): ...@@ -507,7 +507,9 @@ def register_canonicalize(lopt, *tags, **kwargs):
return register return register
else: else:
name = kwargs.pop("name", None) or lopt.__name__ name = kwargs.pop("name", None) or lopt.__name__
compile.optdb["canonicalize"].register(name, lopt, "fast_run", *tags, **kwargs) compile.optdb["canonicalize"].register(
name, lopt, "fast_run", "fast_compile", *tags, **kwargs
)
return lopt return lopt
......
...@@ -48,7 +48,7 @@ class TestPyDotFormatter: ...@@ -48,7 +48,7 @@ class TestPyDotFormatter:
assert len(sub_graphs) == 2 assert len(sub_graphs) == 2
ofg1, ofg2 = sub_graphs ofg1, ofg2 = sub_graphs
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
assert len(ofg1.get_nodes()) == 9 assert len(ofg1.get_nodes()) == 8
else: else:
assert len(ofg1.get_nodes()) == 5 assert len(ofg1.get_nodes()) == 5
assert len(ofg1.get_nodes()) == len(ofg2.get_nodes()) assert len(ofg1.get_nodes()) == len(ofg2.get_nodes())
......
...@@ -1563,18 +1563,8 @@ class TestDots(utt.InferShapeTester): ...@@ -1563,18 +1563,8 @@ class TestDots(utt.InferShapeTester):
assert np.all(f_a(vx, vy) == f_b(vx, vy)) assert np.all(f_a(vx, vy) == f_b(vx, vy))
topo = f_a.maker.fgraph.toposort() topo = f_a.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE": assert not any(
nb = 0 isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
else:
nb = 1
assert (
sum(
[
isinstance(node.op, (Dot, Usmm, UsmmCscDense))
for node in topo
]
)
== nb
) )
def test_int32_dtype(self): def test_int32_dtype(self):
...@@ -1822,13 +1812,8 @@ class TestUsmm: ...@@ -1822,13 +1812,8 @@ class TestUsmm:
) )
assert all(f_shape(a_data, x_data, y_data) == f_b_out.shape) assert all(f_shape(a_data, x_data, y_data) == f_b_out.shape)
topo = f_shape.maker.fgraph.toposort() topo = f_shape.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE": assert not any(
nb = 0 isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
else:
nb = 1
assert (
sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo])
== nb
) )
......
...@@ -1028,11 +1028,6 @@ class TestJoinAndSplit: ...@@ -1028,11 +1028,6 @@ class TestJoinAndSplit:
self.split_op_class = Split self.split_op_class = Split
self.make_vector_op = MakeVector() self.make_vector_op = MakeVector()
self.floatX = config.floatX self.floatX = config.floatX
self.hide_error = config.mode not in [
"DebugMode",
"DEBUG_MODE",
"FAST_COMPILE",
]
self.shared = shared self.shared = shared
def eval_outputs_and_check_join(self, outputs): def eval_outputs_and_check_join(self, outputs):
...@@ -1712,17 +1707,6 @@ class TestJoinAndSplit: ...@@ -1712,17 +1707,6 @@ class TestJoinAndSplit:
for node in topo: for node in topo:
assert not isinstance(node.op, type(self.join_op)) assert not isinstance(node.op, type(self.join_op))
with config.change_flags(compute_test_value="off"):
# Test hide error
x1.set_value(get_mat(3, 4))
x2.set_value(get_mat(3, 4))
x3.set_value(get_mat(2, 5))
if not self.hide_error:
with pytest.raises(ValueError):
f()
else:
f()
def test_rebroadcast(self): def test_rebroadcast(self):
# Regression test for a crash that used to happen when rebroadcasting. # Regression test for a crash that used to happen when rebroadcasting.
x = TensorType(self.floatX, [False, False, True])() x = TensorType(self.floatX, [False, False, True])()
......
...@@ -439,14 +439,6 @@ def makeSharedTester( ...@@ -439,14 +439,6 @@ def makeSharedTester(
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
specify_shape_fct() specify_shape_fct()
# No assertion will be raised as the Op is removed from the graph
# when their is optimization
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
shape_constant_fct()
else:
with pytest.raises(AssertionError):
shape_constant_fct()
def test_specify_shape_partial(self): def test_specify_shape_partial(self):
dtype = self.dtype dtype = self.dtype
if dtype is None: if dtype is None:
...@@ -502,13 +494,6 @@ def makeSharedTester( ...@@ -502,13 +494,6 @@ def makeSharedTester(
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
specify_shape_fct() specify_shape_fct()
# No assertion will be raised as the Op is removed from the graph
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
shape_constant_fct()
else:
with pytest.raises(AssertionError):
shape_constant_fct()
def test_specify_shape_inplace(self): def test_specify_shape_inplace(self):
# test that specify_shape don't break inserting inplace op # test that specify_shape don't break inserting inplace op
......
...@@ -1077,7 +1077,8 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1077,7 +1077,8 @@ class TestSubtensor(utt.OptimizationTestMixin):
else: else:
ops = subtensor_ops ops = subtensor_ops
if idx is idxs[0]: if idx is idxs[0]:
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=2) # TODO FIXME: This is a very poorly specified test.
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0)
f() f()
def test_wrong_exception_regression(self): def test_wrong_exception_regression(self):
...@@ -1129,7 +1130,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1129,7 +1130,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
data = np.asarray(data, dtype=self.dtype) data = np.asarray(data, dtype=self.dtype)
n = self.shared(data) n = self.shared(data)
t = n[idx] t = n[idx]
f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=1) f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=0)
val = f() val = f()
assert np.allclose(val, data[idx].shape) assert np.allclose(val, data[idx].shape)
......
...@@ -1174,10 +1174,6 @@ class TestLocalSubtensorMerge: ...@@ -1174,10 +1174,6 @@ class TestLocalSubtensorMerge:
data = self.rng.uniform(size=(8, 8, 8)).astype(config.floatX) data = self.rng.uniform(size=(8, 8, 8)).astype(config.floatX)
x = tensor3("x") x = tensor3("x")
nops = 1
if config.mode == "FAST_COMPILE":
nops = 2
# test 1) # test 1)
y = x[3:6, 2:6, 1:7][1] y = x[3:6, 2:6, 1:7][1]
fun = function([x], y) fun = function([x], y)
...@@ -1185,7 +1181,7 @@ class TestLocalSubtensorMerge: ...@@ -1185,7 +1181,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[3:6, 2:6, 1:7][1]) assert np.all(val == data[3:6, 2:6, 1:7][1])
assert ( assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)]) len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops == 1
) )
# test 2) # test 2)
...@@ -1195,7 +1191,7 @@ class TestLocalSubtensorMerge: ...@@ -1195,7 +1191,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[2, 3][1]) assert np.all(val == data[2, 3][1])
assert ( assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)]) len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops == 1
) )
# test 3) # test 3)
...@@ -1205,7 +1201,7 @@ class TestLocalSubtensorMerge: ...@@ -1205,7 +1201,7 @@ class TestLocalSubtensorMerge:
assert np.all(val == data[3:6, 2, 1:7][1]) assert np.all(val == data[3:6, 2, 1:7][1])
assert ( assert (
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)]) len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
== nops == 1
) )
def test_scalar6(self): def test_scalar6(self):
...@@ -1590,10 +1586,7 @@ class TestAllocZero: ...@@ -1590,10 +1586,7 @@ class TestAllocZero:
assert len(inc_nodes) == 1 assert len(inc_nodes) == 1
node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc
mode = config.mode assert node_is_set_instead_of_inc
assert (mode != "FAST_COMPILE" and node_is_set_instead_of_inc) or (
mode == "FAST_COMPILE" and not node_is_set_instead_of_inc
)
test_X = np.random.random((4, 4)).astype(config.floatX) test_X = np.random.random((4, 4)).astype(config.floatX)
utt.assert_allclose(f(test_X), test_X) utt.assert_allclose(f(test_X), test_X)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论