提交 5a39ef60 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not check rewrites based on string representation

上级 0a839391
......@@ -12,6 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -1410,31 +1411,28 @@ class TestLiftTransposeThroughDot:
def test_matrix_matrix(self):
a, b = matrices("ab")
g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T]))
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))"
assert str(g) == sg, (str(g), sg)
g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T], clone=False))
assert equal_computations(g.outputs, [dot(b.T, a.T)])
assert check_stack_trace(g, ops_to_check="all")
def test_row_matrix(self):
a = vector("a")
b = matrix("b")
g = rewrite(
FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]),
FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T], clone=False),
level="stabilize",
)
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))"
assert str(g) == sg, (str(g), sg)
assert equal_computations(g.outputs, [dot(b.T, a.dimshuffle(0, "x"))])
assert check_stack_trace(g, ops_to_check="all")
def test_matrix_col(self):
a = vector("a")
b = matrix("b")
g = rewrite(
FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]),
FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T], clone=False),
level="stabilize",
)
sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))"
assert str(g) == sg, (str(g), sg)
assert equal_computations(g.outputs, [dot(a.dimshuffle("x", 0), b.T)])
assert check_stack_trace(g, ops_to_check="all")
......
......@@ -12,7 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
class TestDimshuffleLift:
def test_double_transpose(self):
x, y, z = inputs()
x, *_ = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = FunctionGraph([x], [e])
# TODO FIXME: Construct these graphs and compare them.
assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
)
g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty
def test_merge2(self):
x, y, z = inputs()
x, *_ = inputs()
e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
g = FunctionGraph([x], [e])
# TODO FIXME: Construct these graphs and compare them.
assert (
str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g)
g = FunctionGraph([x], [e], clone=False)
assert len(g.apply_nodes) == 2
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
assert equal_computations(g.outputs, [x.dimshuffle(0, 1, "x", "x")])
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
g = FunctionGraph([x], [e])
# TODO FIXME: Construct these graphs and compare them.
assert str(g) == (
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))"
), str(g)
g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)", str(g)
assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty
def test_lift(self):
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z
g = FunctionGraph([x, y, z], [e])
# TODO FIXME: Construct these graphs and compare them.
# It does not really matter if the DimShuffles are inplace
# or not.
init_str_g_inplace = (
"FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
)
init_str_g_noinplace = (
"FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
)
assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
rewrite_str_g_inplace = (
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
)
rewrite_str_g_noinplace = (
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
)
g = FunctionGraph([x, y, z], [e], clone=False)
dimshuffle_lift.rewrite(g)
assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
assert equal_computations(
g.outputs,
[(x.dimshuffle("x", "x", 0) + y.dimshuffle("x", 0, 1)) + z],
)
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
def test_recursive_lift(self):
v = vector(dtype="float64")
m = matrix(dtype="float64")
v = vector("v", dtype="float64")
m = matrix("m", dtype="float64")
out = ((v + 42) * (m + 84)).T
g = FunctionGraph([v, m], [out])
# TODO FIXME: Construct these graphs and compare them.
init_str_g = (
"FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
"(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
"(<TensorType(float64, (?,))>, "
"InplaceDimShuffle{x}(TensorConstant{42}))), "
"Elemwise{add,no_inplace}"
"(<TensorType(float64, (?, ?))>, "
"InplaceDimShuffle{x,x}(TensorConstant{84})))))"
)
assert str(g) == init_str_g
new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
rewrite_str_g = (
"FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
"InplaceDimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
"(<TensorType(float64, (?, ?))>), "
"InplaceDimShuffle{x,x}(TensorConstant{84}))))"
g = FunctionGraph([v, m], [out], clone=False)
new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)
assert equal_computations(
new_out,
[(v.dimshuffle(0, "x") + 42) * (m.T + 84)],
)
assert str(new_g) == rewrite_str_g
# Check stacktrace was copied over correctly after rewrite was applied
new_g = FunctionGraph(g.inputs, new_out, clone=False)
assert check_stack_trace(new_g, ops_to_check="all")
def test_useless_dimshuffle(self):
x, _, _ = inputs()
x, *_ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
# TODO FIXME: Construct these graphs and compare them.
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
g = FunctionGraph([x], [e], clone=False)
assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
assert str(g) == "FunctionGraph(x)"
assert g.outputs[0] is x
# Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -203,17 +156,10 @@ class TestDimshuffleLift:
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # useful
ds_u = ds(u, ("x")) # useful
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
# TODO FIXME: Construct these graphs and compare them.
assert (
str(g)
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u], clone=False)
assert len(g.apply_nodes) == 4
dimshuffle_lift.rewrite(g)
assert (
str(g)
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
assert equal_computations(g.outputs, [x, y, z.T, u.dimshuffle("x")])
# Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_row,
reshape_dimshuffle_col,
],
clone=False,
)
# TODO FIXME: Construct these graphs and compare them.
assert str(g) == (
"FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
"Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
"Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
)
assert len(g.apply_nodes) == 4 * 3
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.rewrite(g)
assert str(g) == (
"FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), "
"Reshape{2}(row, Shape(row)), "
"Reshape{2}(col, Shape(col)))"
assert equal_computations(
g.outputs,
[
reshape(vec, vec.shape),
reshape(mat, mat.shape),
reshape(row, row.shape),
reshape(col, col.shape),
],
)
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
# Check that the rewrite does not get applied when the order
# of dimensions has changed.
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
assert len(h.apply_nodes) == 3
useless_dimshuffle_in_reshape.rewrite(h)
assert str(h) == str_h
assert equal_computations(
h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
)
class TestFusion:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论