提交 b6804245 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Clean up FunctionGraph str and repr implementation

上级 55ca059a
......@@ -144,12 +144,12 @@ def test_misc():
g = Env([x, y, z], [e])
assert g.consistent()
PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g)
assert str(g) == "[x]"
assert str(g) == "FunctionGraph(x)"
new_e = add(x, y)
g.replace_validate(x, new_e)
assert str(g) == "[Add(x, y)]"
assert str(g) == "FunctionGraph(Add(x, y))"
g.replace(new_e, dot(add_in_place(x, y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert str(g) == "FunctionGraph(Dot(AddInPlace(x, y), TransposeView(x)))"
assert not g.consistent()
......@@ -325,7 +325,10 @@ def test_long_destroyers_loop():
OpSubOptimizer(add, add_in_place).optimize(g)
assert g.consistent()
# we don't want to see that!
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]"
assert (
str(g)
!= "FunctionGraph(Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x)))"
)
e2 = dot(dot(add_in_place(x, y), add_in_place(y, z)), add_in_place(z, x))
with pytest.raises(InconsistencyError):
Env(*graph.clone([x, y, z], [e2]))
......
......@@ -43,7 +43,7 @@ class TestPatternOptimizer:
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).optimize(g)
assert str(g) == "[Op4(z, y)]"
assert str(g) == "FunctionGraph(Op4(z, y))"
def test_nested_out_pattern(self):
x, y, z = inputs()
......@@ -52,7 +52,7 @@ class TestPatternOptimizer:
PatternOptimizer(
(op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
def test_unification_1(self):
x, y, z = inputs()
......@@ -63,7 +63,7 @@ class TestPatternOptimizer:
(op4, "2", "1"),
).optimize(g)
# So the replacement should occur
assert str(g) == "[Op4(z, x)]"
assert str(g) == "FunctionGraph(Op4(z, x))"
def test_unification_2(self):
x, y, z = inputs()
......@@ -74,7 +74,7 @@ class TestPatternOptimizer:
(op4, "2", "1"),
).optimize(g)
# The replacement should NOT occur
assert str(g) == "[Op1(Op2(x, y), z)]"
assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))"
def test_replace_subgraph(self):
# replacing inside the graph
......@@ -82,7 +82,7 @@ class TestPatternOptimizer:
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]"
assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_no_recurse(self):
# if the out pattern is an acceptable in pattern
......@@ -92,7 +92,7 @@ class TestPatternOptimizer:
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]"
assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_multiple(self):
# it should replace all occurrences of the pattern
......@@ -100,7 +100,7 @@ class TestPatternOptimizer:
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_nested_even(self):
# regardless of the order in which we optimize, this
......@@ -109,21 +109,21 @@ class TestPatternOptimizer:
e = op1(op1(op1(op1(x))))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g)
assert str(g) == "[x]"
assert str(g) == "FunctionGraph(x)"
def test_nested_odd(self):
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, "1")), "1").optimize(g)
assert str(g) == "[Op1(x)]"
assert str(g) == "FunctionGraph(Op1(x))"
def test_expand(self):
x, y, z = inputs()
e = op1(op1(op1(x)))
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self):
# this test should always work with TopoOptimizer and the
......@@ -133,7 +133,7 @@ class TestPatternOptimizer:
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g)
assert str(g) == "[Op1(x)]"
assert str(g) == "FunctionGraph(Op1(x))"
def test_constant_unification(self):
x = Constant(MyType(), 2, name="x")
......@@ -142,7 +142,7 @@ class TestPatternOptimizer:
e = op1(op1(x, y), y)
g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, "1"), (op2, "1", z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]"
assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
def test_constraints(self):
x, y, z = inputs()
......@@ -156,14 +156,14 @@ class TestPatternOptimizer:
PatternOptimizer(
(op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
).optimize(g)
assert str(g) == "[Op4(Op3(Op2(x, y)), Op1(Op1(x, y)))]"
assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
def test_match_same(self):
x, y, z = inputs()
e = op1(x, x)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g)
assert str(g) == "[Op3(x, x)]"
assert str(g) == "FunctionGraph(Op3(x, x))"
def test_match_same_illegal(self):
x, y, z = inputs()
......@@ -177,7 +177,7 @@ class TestPatternOptimizer:
PatternOptimizer(
{"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
).optimize(g)
assert str(g) == "[Op2(Op1(x, x), Op3(x, y))]"
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_multi(self):
x, y, z = inputs()
......@@ -185,7 +185,7 @@ class TestPatternOptimizer:
e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
def test_eq(self):
# replacing the whole graph
......@@ -194,7 +194,7 @@ class TestPatternOptimizer:
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(g)
str_g = str(g)
assert str_g == "[Op4(z, y)]"
assert str_g == "FunctionGraph(Op4(z, y))"
# def test_multi_ingraph(self):
......@@ -205,7 +205,7 @@ class TestPatternOptimizer:
# g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]"
# assert str(g) == "FunctionGraph(Op3(x, y))"
def OpSubOptimizer(op1, op2):
......@@ -218,14 +218,14 @@ class TestOpSubOptimizer:
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self):
x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
class NoInputOp(Op):
......@@ -247,7 +247,7 @@ class TestMergeOptimizer:
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
assert str(g) == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, Op2(x, z)))"
def test_constant_merging(self):
x = MyVariable("x")
......@@ -258,8 +258,8 @@ class TestMergeOptimizer:
MergeOptimizer().optimize(g)
strg = str(g)
assert (
strg == "[Op1(*1 -> Op2(x, y), *1, *1)]"
or strg == "[Op1(*1 -> Op2(x, z), *1, *1)]"
strg == "FunctionGraph(Op1(*1 -> Op2(x, y), *1, *1))"
or strg == "FunctionGraph(Op1(*1 -> Op2(x, z), *1, *1))"
)
def test_deep_merge(self):
......@@ -267,14 +267,14 @@ class TestMergeOptimizer:
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
assert str(g) == "FunctionGraph(Op1(*1 -> Op3(Op2(x, y), z), Op4(*1)))"
def test_no_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
assert str(g) == "FunctionGraph(Op1(Op3(Op2(x, y)), Op3(Op2(y, x))))"
def test_merge_outputs(self):
x, y, z = inputs()
......@@ -282,7 +282,7 @@ class TestMergeOptimizer:
e2 = op3(op2(x, y))
g = FunctionGraph([x, y, z], [e1, e2])
MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
assert str(g) == "FunctionGraph(*1 -> Op3(Op2(x, y)), *1)"
def test_multiple_merges(self):
x, y, z = inputs()
......@@ -295,9 +295,10 @@ class TestMergeOptimizer:
# note: graph.as_string can only produce the following two possibilities, but if
# the implementation was to change there are 6 other acceptable answers.
assert (
strg == "[Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2))]"
strg
== "FunctionGraph(Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2)))"
or strg
== "[Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1))]"
== "FunctionGraph(Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1)))"
)
def test_identical_constant_args(self):
......@@ -313,7 +314,7 @@ class TestMergeOptimizer:
g = FunctionGraph([x, y, z], [e1])
MergeOptimizer().optimize(g)
strg = str(g)
assert strg == "[Op1(y, y)]" or strg == "[Op1(z, z)]"
assert strg == "FunctionGraph(Op1(y, y))" or strg == "FunctionGraph(Op1(z, z))"
def est_one_assert_merge(self):
# Merge two nodes, one has assert, the other not.
......@@ -494,7 +495,7 @@ class TestEquilibrium:
)
opt.optimize(g)
# print g
assert str(g) == "[Op2(x, y)]"
assert str(g) == "FunctionGraph(Op2(x, y))"
def test_2(self):
x, y, z = map(MyVariable, "xyz")
......@@ -512,7 +513,7 @@ class TestEquilibrium:
max_use_ratio=10,
)
opt.optimize(g)
assert str(g) == "[Op2(x, y)]"
assert str(g) == "FunctionGraph(Op2(x, y))"
@theano.change_flags(on_opt_error="ignore")
def test_low_use_ratio(self):
......@@ -538,7 +539,7 @@ class TestEquilibrium:
finally:
_logger.setLevel(oldlevel)
# print 'after', g
assert str(g) == "[Op1(x, y)]"
assert str(g) == "FunctionGraph(Op1(x, y))"
def test_pre_constant_merge_slice():
......
......@@ -173,12 +173,12 @@ class TestComposite:
gof.DualLinker().accept(g).make_function()
assert str(g) == (
"[*1 -> Composite{((i0 + i1) + i2),"
"FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
" (i0 + (i1 * i2)), (i0 / i1), "
"(i0 // Constant{5}), "
"(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
" (i0 % Constant{3})}(x, y, z), "
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7]"
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_make_node_continue_graph(self):
......
......@@ -119,9 +119,11 @@ class TestDimshuffleLift:
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = FunctionGraph([x], [e])
assert str(g) == "[InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x))]"
assert (
str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
)
dimshuffle_lift.optimize(g)
assert str(g) == "[x]"
assert str(g) == "FunctionGraph(x)"
# no need to check_stack_trace as graph is supposed to be empty
def test_merge2(self):
......@@ -129,10 +131,11 @@ class TestDimshuffleLift:
e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
g = FunctionGraph([x], [e])
assert (
str(g) == "[InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x))]"
str(g)
== "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
), str(g)
dimshuffle_lift.optimize(g)
assert str(g) == "[InplaceDimShuffle{0,1,x,x}(x)]", str(g)
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -141,11 +144,11 @@ class TestDimshuffleLift:
e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
g = FunctionGraph([x], [e])
assert str(g) == (
"[InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x)))]"
"FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
"(InplaceDimShuffle{0,x,1}(x))))"
), str(g)
dimshuffle_lift.optimize(g)
assert str(g) == "[x]", str(g)
assert str(g) == "FunctionGraph(x)", str(g)
# no need to check_stack_trace as graph is supposed to be empty
def test_lift(self):
......@@ -156,22 +159,22 @@ class TestDimshuffleLift:
# It does not really matter if the DimShuffles are inplace
# or not.
init_str_g_inplace = (
"[Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z)]"
"FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
)
init_str_g_noinplace = (
"[Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z)]"
"FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
)
assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
opt_str_g_inplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]"
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
)
opt_str_g_noinplace = (
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z)]"
"FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
)
dimshuffle_lift.optimize(g)
assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g)
......@@ -184,24 +187,24 @@ class TestDimshuffleLift:
out = ((v + 42) * (m + 84)).T
g = FunctionGraph([v, m], [out])
init_str_g = (
"[InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
"FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
"(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
"(<TensorType(float64, vector)>, "
"InplaceDimShuffle{x}(TensorConstant{42}))), "
"Elemwise{add,no_inplace}"
"(<TensorType(float64, matrix)>, "
"InplaceDimShuffle{x,x}(TensorConstant{84}))))]"
"InplaceDimShuffle{x,x}(TensorConstant{84})))))"
)
assert str(g) == init_str_g
new_out = local_dimshuffle_lift.transform(g.outputs[0].owner)[0]
new_g = FunctionGraph(g.inputs, [new_out])
opt_str_g = (
"[Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{0,x}(<TensorType(float64, vector)>), "
"InplaceDimShuffle{x,x}(TensorConstant{42})), "
"Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
"(<TensorType(float64, matrix)>), "
"InplaceDimShuffle{x,x}(TensorConstant{84})))]"
"InplaceDimShuffle{x,x}(TensorConstant{84}))))"
)
assert str(new_g) == opt_str_g
# Check stacktrace was copied over correctly after opt was applied
......@@ -211,9 +214,9 @@ class TestDimshuffleLift:
x, _, _ = inputs()
e = ds(x, (0, 1))
g = FunctionGraph([x], [e])
assert str(g) == "[InplaceDimShuffle{0,1}(x)]"
assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
dimshuffle_lift.optimize(g)
assert str(g) == "[x]"
assert str(g) == "FunctionGraph(x)"
# Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -227,12 +230,12 @@ class TestDimshuffleLift:
g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
assert (
str(g)
== "[InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1})]"
== "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
dimshuffle_lift.optimize(g)
assert (
str(g)
== "[x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1})]"
== "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
)
# Check stacktrace was copied over correctly after opt was applied
assert hasattr(g.outputs[0].tag, "trace")
......@@ -261,18 +264,18 @@ def test_local_useless_dimshuffle_in_reshape():
print(str(g))
assert str(g) == (
"[Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
"FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
"Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
"Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col))]"
"Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
)
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.optimize(g)
assert str(g) == (
"[Reshape{1}(vector, Shape(vector)), "
"FunctionGraph(Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), "
"Reshape{2}(row, Shape(row)), "
"Reshape{2}(col, Shape(col))]"
"Reshape{2}(col, Shape(col)))"
)
# Check stacktrace was copied over correctly after opt was applied
......@@ -4762,7 +4765,7 @@ class TestLocalCanonicalizeAlloc:
g = FunctionGraph([x, y, z, w], [alloc_x, alloc_y, alloc_z, alloc_w])
assert str(g) == (
"[Alloc(<TensorType(float64, vector)>, "
"FunctionGraph(Alloc(<TensorType(float64, vector)>, "
"TensorConstant{1}, "
"TensorConstant{3}, "
"TensorConstant{2}), "
......@@ -4775,12 +4778,12 @@ class TestLocalCanonicalizeAlloc:
"TensorConstant{2}), "
"Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{2})]"
"TensorConstant{2}))"
)
alloc_lift.optimize(g)
assert str(g) == (
"[InplaceDimShuffle{x,0,1}"
"FunctionGraph(InplaceDimShuffle{x,0,1}"
"(Alloc(<TensorType(float64, vector)>, "
"TensorConstant{3}, "
"TensorConstant{2})), "
......@@ -4792,7 +4795,7 @@ class TestLocalCanonicalizeAlloc:
"TensorConstant{2})), "
"Alloc(<TensorType(float64, matrix)>, "
"TensorConstant{1}, "
"TensorConstant{2})]"
"TensorConstant{2}))"
)
# Check stacktrace was copied over correctly after opt was applied
......@@ -7666,22 +7669,22 @@ class TestLocalReshapeToDimshuffle:
g = FunctionGraph([x, y], [reshape_x, reshape_y])
assert str(g) == (
"[Reshape{2}"
"FunctionGraph(Reshape{2}"
"(<TensorType(float64, vector)>, "
"TensorConstant{[1 4]}), "
"Reshape{6}"
"(<TensorType(float64, matrix)>, "
"TensorConstant{[1 5 1 6 1 1]})]"
"TensorConstant{[1 5 1 6 1 1]}))"
)
reshape_lift.optimize(g)
useless_reshape.optimize(g)
assert str(g) == (
"[InplaceDimShuffle{x,0}"
"FunctionGraph(InplaceDimShuffle{x,0}"
"(<TensorType(float64, vector)>), "
"InplaceDimShuffle{x,0,x,1,x,x}"
"(Reshape{2}(<TensorType(float64, matrix)>, "
"TensorConstant{[5 6]}))]"
"TensorConstant{[5 6]})))"
)
# Check stacktrace was copied over correctly after opt was applied
......@@ -7713,7 +7716,7 @@ class TestLiftTransposeThroughDot:
def test_matrix_matrix(self):
a, b = matrices("ab")
g = self.simple_optimize(FunctionGraph([a, b], [tt.dot(a, b).T]))
sg = "[dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a))]"
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))"
assert str(g) == sg, (str(g), sg)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -7725,7 +7728,7 @@ class TestLiftTransposeThroughDot:
FunctionGraph([a, b], [tt.dot(a.dimshuffle("x", 0), b).T]),
level="stabilize",
)
sg = "[dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a))]"
sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))"
assert str(g) == sg, (str(g), sg)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......@@ -7737,7 +7740,7 @@ class TestLiftTransposeThroughDot:
FunctionGraph([a, b], [tt.dot(b, a.dimshuffle(0, "x")).T]),
level="stabilize",
)
sg = "[dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b))]"
sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))"
assert str(g) == sg, (str(g), sg)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(g, ops_to_check="all")
......
......@@ -812,11 +812,8 @@ class FunctionGraph(utils.object2):
"Inconsistent clients list.", variable, node.inputs[i]
)
def __str__(self):
return f"[{', '.join(graph.as_string(self.inputs, self.outputs))}]"
def __repr__(self):
return self.__str__()
return f"FunctionGraph({', '.join(graph.as_string(self.inputs, self.outputs))})"
def clone(self, check_integrity=True):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论