提交 5a39ef60 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not check rewrites based on string representation

上级 0a839391
...@@ -12,6 +12,7 @@ from pytensor.compile.function import function ...@@ -12,6 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -1410,31 +1411,28 @@ class TestLiftTransposeThroughDot: ...@@ -1410,31 +1411,28 @@ class TestLiftTransposeThroughDot:
def test_matrix_matrix(self): def test_matrix_matrix(self):
a, b = matrices("ab") a, b = matrices("ab")
g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T])) g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T], clone=False))
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))" assert equal_computations(g.outputs, [dot(b.T, a.T)])
assert str(g) == sg, (str(g), sg)
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
def test_row_matrix(self): def test_row_matrix(self):
a = vector("a") a = vector("a")
b = matrix("b") b = matrix("b")
g = rewrite( g = rewrite(
FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]), FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T], clone=False),
level="stabilize", level="stabilize",
) )
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))" assert equal_computations(g.outputs, [dot(b.T, a.dimshuffle(0, "x"))])
assert str(g) == sg, (str(g), sg)
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
def test_matrix_col(self): def test_matrix_col(self):
a = vector("a") a = vector("a")
b = matrix("b") b = matrix("b")
g = rewrite( g = rewrite(
FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]), FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T], clone=False),
level="stabilize", level="stabilize",
) )
sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))" assert equal_computations(g.outputs, [dot(a.dimshuffle("x", 0), b.T)])
assert str(g) == sg, (str(g), sg)
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
......
...@@ -12,7 +12,7 @@ from pytensor.compile.function import function ...@@ -12,7 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): ...@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
class TestDimshuffleLift: class TestDimshuffleLift:
def test_double_transpose(self): def test_double_transpose(self):
x, y, z = inputs() x, *_ = inputs()
e = ds(ds(x, (1, 0)), (1, 0)) e = ds(ds(x, (1, 0)), (1, 0))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e], clone=False)
# TODO FIXME: Construct these graphs and compare them. assert isinstance(g.outputs[0].owner.op, DimShuffle)
assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
def test_merge2(self): def test_merge2(self):
x, y, z = inputs() x, *_ = inputs()
e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1)) e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e], clone=False)
# TODO FIXME: Construct these graphs and compare them. assert len(g.apply_nodes) == 2
assert (
str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) assert equal_computations(g.outputs, [x.dimshuffle(0, 1, "x", "x")])
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
def test_elim3(self): def test_elim3(self):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0)) e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e], clone=False)
# TODO FIXME: Construct these graphs and compare them. assert isinstance(g.outputs[0].owner.op, DimShuffle)
assert str(g) == (
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))"
), str(g)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g) assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty # no need to check_stack_trace as graph is supposed to be empty
def test_lift(self): def test_lift(self):
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3) x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z e = x + y + z
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
# TODO FIXME: Construct these graphs and compare them.
# It does not really matter if the DimShuffles are inplace
# or not.
init_str_g_inplace = (
"FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
)
init_str_g_noinplace = (
"FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
)
assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
rewrite_str_g_inplace = (
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
)
rewrite_str_g_noinplace = (
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g) assert equal_computations(
g.outputs,
[(x.dimshuffle("x", "x", 0) + y.dimshuffle("x", 0, 1)) + z],
)
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
def test_recursive_lift(self): def test_recursive_lift(self):
v = vector(dtype="float64") v = vector("v", dtype="float64")
m = matrix(dtype="float64") m = matrix("m", dtype="float64")
out = ((v + 42) * (m + 84)).T out = ((v + 42) * (m + 84)).T
g = FunctionGraph([v, m], [out]) g = FunctionGraph([v, m], [out], clone=False)
# TODO FIXME: Construct these graphs and compare them. new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)
init_str_g = ( assert equal_computations(
"FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}" new_out,
"(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}" [(v.dimshuffle(0, "x") + 42) * (m.T + 84)],
"(<TensorType(float64, (?,))>, "
"InplaceDimShuffle{x}(TensorConstant{42}))), "
"Elemwise{add,no_inplace}"
"(<TensorType(float64, (?, ?))>, "
"InplaceDimShuffle{x,x}(TensorConstant{84})))))"
)
assert str(g) == init_str_g
new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
rewrite_str_g = (
"FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
"InplaceDimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
"(<TensorType(float64, (?, ?))>), "
"InplaceDimShuffle{x,x}(TensorConstant{84}))))"
) )
assert str(new_g) == rewrite_str_g
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
new_g = FunctionGraph(g.inputs, new_out, clone=False)
assert check_stack_trace(new_g, ops_to_check="all") assert check_stack_trace(new_g, ops_to_check="all")
def test_useless_dimshuffle(self): def test_useless_dimshuffle(self):
x, _, _ = inputs() x, *_ = inputs()
e = ds(x, (0, 1)) e = ds(x, (0, 1))
g = FunctionGraph([x], [e]) g = FunctionGraph([x], [e], clone=False)
# TODO FIXME: Construct these graphs and compare them. assert isinstance(g.outputs[0].owner.op, DimShuffle)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)" assert g.outputs[0] is x
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace") assert hasattr(g.outputs[0].tag, "trace")
...@@ -203,17 +156,10 @@ class TestDimshuffleLift: ...@@ -203,17 +156,10 @@ class TestDimshuffleLift:
ds_y = ds(y, (2, 1, 0)) # useless ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # useful ds_z = ds(z, (2, 1, 0)) # useful
ds_u = ds(u, ("x")) # useful ds_u = ds(u, ("x")) # useful
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u]) g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u], clone=False)
# TODO FIXME: Construct these graphs and compare them. assert len(g.apply_nodes) == 4
assert (
str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
dimshuffle_lift.rewrite(g) dimshuffle_lift.rewrite(g)
assert ( assert equal_computations(g.outputs, [x, y, z.T, u.dimshuffle("x")])
str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace") assert hasattr(g.outputs[0].tag, "trace")
...@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape(): ...@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_row, reshape_dimshuffle_row,
reshape_dimshuffle_col, reshape_dimshuffle_col,
], ],
clone=False,
) )
assert len(g.apply_nodes) == 4 * 3
# TODO FIXME: Construct these graphs and compare them.
assert str(g) == (
"FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
"Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
"Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
)
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.rewrite(g) useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == ( assert equal_computations(
"FunctionGraph(Reshape{1}(vector, Shape(vector)), " g.outputs,
"Reshape{2}(mat, Shape(mat)), " [
"Reshape{2}(row, Shape(row)), " reshape(vec, vec.shape),
"Reshape{2}(col, Shape(col)))" reshape(mat, mat.shape),
reshape(row, row.shape),
reshape(col, col.shape),
],
) )
# Check stacktrace was copied over correctly after rewrite was applied # Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all") assert check_stack_trace(g, ops_to_check="all")
# Check that the rewrite does not get applied when the order # Check that the rewrite does not get applied when the order
# of dimensions has changed. # of dimensions has changed.
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
str_h = str(h) assert len(h.apply_nodes) == 3
useless_dimshuffle_in_reshape.rewrite(h) useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h assert equal_computations(
h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
)
class TestFusion: class TestFusion:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论