提交 871eee1b authored 作者: Frederic's avatar Frederic

Fix crash for Composite with multiple outputs just introduced and test the flatte.

上级 4f15d801
......@@ -2980,7 +2980,8 @@ class Composite(ScalarOp):
# that will flatten Composite. We don't need to do this
# recusively, as the way the fusion optimizer work, we have
# only 1 new Composite each time at the output.
if not any([isinstance(var.owner.op, Composite) for var in outputs]):
if len(outputs) > 1 or not any([isinstance(var.owner.op, Composite)
for var in outputs]):
# No inner Composite
inputs, outputs = gof.graph.clone(inputs, outputs)
else:
......@@ -3004,8 +3005,8 @@ class Composite(ScalarOp):
assert res[0] != inputs
inputs, outputs = res[0], res2[1]
# Next assert comment just for speed
assert not any([isinstance(node.op, Composite) for node in
theano.gof.graph.ops(inputs, outputs)])
#assert not any([isinstance(node.op, Composite) for node in
# theano.gof.graph.ops(inputs, outputs)])
self.inputs = copy(inputs)
self.outputs = copy(outputs)
......
......@@ -68,19 +68,17 @@ class test_composite(unittest.TestCase):
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5
# def test_sin(self):
# x = inputs()
# e = sin(x)
# C = Composite([x], [e])
# c = C.make_node(x)
# # print c.c_code(['x'], ['z'], dict(id = 0))
# g = FunctionGraph([x], [c.out])
# fn = gof.DualLinker().accept(g).make_function()
# assert fn(0) == 0
# assert fn(3.14159265358/2) == 1
# assert fn(3.14159265358) == 0
# WRITEME: Test for sin, pow, and other scalar ops.
def test_flatten(self):
#Test that we flatten multiple Composite.
x, y, z = inputs()
C = Composite([x, y], [x + y])
CC = Composite([x, y], [C(x * y, y)])
assert not isinstance(CC.outputs[0].owner.op, Composite)
# Test with multiple outputs
CC = Composite([x, y, z], [C(x * y, y), C(x * z, y)])
#We don't flatten that case.
assert isinstance(CC.outputs[0].owner.op, Composite)
def test_with_constants(self):
x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论