提交 71a44de8 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

tests for printing a Composite are added.

上级 21d3f510
...@@ -3108,7 +3108,8 @@ class Composite(ScalarOp): ...@@ -3108,7 +3108,8 @@ class Composite(ScalarOp):
for i, r in enumerate(self.fgraph.variables): for i, r in enumerate(self.fgraph.variables):
if r not in io and len(r.clients) > 1: if r not in io and len(r.clients) > 1:
r.name = 't%i' % i r.name = 't%i' % i
rval = "Composite{%s}" % pprint(self.fgraph.outputs[0]) rval = "Composite{%s}" % ', '.join([pprint(output) for output
in self.fgraph.outputs])
self.name = rval self.name = rval
def init_fgraph(self): def init_fgraph(self):
......
...@@ -146,6 +146,17 @@ class test_composite(unittest.TestCase): ...@@ -146,6 +146,17 @@ class test_composite(unittest.TestCase):
fn = gof.DualLinker().accept(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_composite_printing(self):
x, y, z = floats('xyz')
e0 = x + y + z
e1 = x + y * z
e2 = x / y
C = Composite([x, y, z], [e0, e1, e2])
c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs)
fn = gof.DualLinker().accept(g).make_function()
assert str(g) == '[*1 -> Composite{((i0 + i1) + i2), (i0 + (i1 * i2)), (i0 / i1)}(x, y, z), *1::1, *1::2]'
def test_make_node_continue_graph(self): def test_make_node_continue_graph(self):
# This is a test for a bug (now fixed) that disabled the # This is a test for a bug (now fixed) that disabled the
# local_gpu_elemwise_0 optimization and printed an # local_gpu_elemwise_0 optimization and printed an
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论