提交 871eee1b authored 作者: Frederic's avatar Frederic

Fix crash for Composite with multiple outputs just introduced and test the flatte.

上级 4f15d801
...@@ -2980,7 +2980,8 @@ class Composite(ScalarOp): ...@@ -2980,7 +2980,8 @@ class Composite(ScalarOp):
# that will flatten Composite. We don't need to do this # that will flatten Composite. We don't need to do this
# recusively, as the way the fusion optimizer work, we have # recusively, as the way the fusion optimizer work, we have
# only 1 new Composite each time at the output. # only 1 new Composite each time at the output.
if not any([isinstance(var.owner.op, Composite) for var in outputs]): if len(outputs) > 1 or not any([isinstance(var.owner.op, Composite)
for var in outputs]):
# No inner Composite # No inner Composite
inputs, outputs = gof.graph.clone(inputs, outputs) inputs, outputs = gof.graph.clone(inputs, outputs)
else: else:
...@@ -3004,8 +3005,8 @@ class Composite(ScalarOp): ...@@ -3004,8 +3005,8 @@ class Composite(ScalarOp):
assert res[0] != inputs assert res[0] != inputs
inputs, outputs = res[0], res2[1] inputs, outputs = res[0], res2[1]
# Next assert comment just for speed # Next assert comment just for speed
assert not any([isinstance(node.op, Composite) for node in #assert not any([isinstance(node.op, Composite) for node in
theano.gof.graph.ops(inputs, outputs)]) # theano.gof.graph.ops(inputs, outputs)])
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.outputs = copy(outputs) self.outputs = copy(outputs)
......
...@@ -68,19 +68,17 @@ class test_composite(unittest.TestCase): ...@@ -68,19 +68,17 @@ class test_composite(unittest.TestCase):
fn = gof.DualLinker().accept(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
# def test_sin(self): def test_flatten(self):
# x = inputs() #Test that we flatten multiple Composite.
# e = sin(x) x, y, z = inputs()
# C = Composite([x], [e]) C = Composite([x, y], [x + y])
# c = C.make_node(x) CC = Composite([x, y], [C(x * y, y)])
# # print c.c_code(['x'], ['z'], dict(id = 0)) assert not isinstance(CC.outputs[0].owner.op, Composite)
# g = FunctionGraph([x], [c.out])
# fn = gof.DualLinker().accept(g).make_function() # Test with multiple outputs
# assert fn(0) == 0 CC = Composite([x, y, z], [C(x * y, y), C(x * z, y)])
# assert fn(3.14159265358/2) == 1 #We don't flatten that case.
# assert fn(3.14159265358) == 0 assert isinstance(CC.outputs[0].owner.op, Composite)
# WRITEME: Test for sin, pow, and other scalar ops.
def test_with_constants(self): def test_with_constants(self):
x, y, z = inputs() x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论