提交 60c39df1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add direct test for nested broadcasted Composite graphs

上级 2c84b496
...@@ -1167,6 +1167,37 @@ class TestFusion: ...@@ -1167,6 +1167,37 @@ class TestFusion:
assert out_val.shape == exp_res.shape assert out_val.shape == exp_res.shape
assert np.allclose(out_val, exp_res) assert np.allclose(out_val, exp_res)
def test_not_fusing_broadcasted_subgraphs(self):
"""Test that broadcasted Elemwise subgraphs are not fused in a single Elemwise Composite Op.
There are some cases in self.test_elemwise_fusion, but this test confirms that the
fused subgraphs are exactly the expected ones.
"""
xs = vector("xm")
xm = matrix("xs")
es = log(xs + 5)
em = exp(xm * 5)
esm = es - em
f = pytensor.function([xs, xm], esm, mode=self.mode)
apply_nodes = f.maker.fgraph.toposort()
assert len(apply_nodes) == 3
assert isinstance(apply_nodes[0].op, DimShuffle)
# Inner Vector output Composite
assert isinstance(apply_nodes[1].op.scalar_op, Composite)
assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == {
aes.add,
aes.log,
}
# Outer Matrix output Composite
assert isinstance(apply_nodes[2].op.scalar_op, Composite)
assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == {
aes.sub,
aes.exp,
aes.mul,
}
class TimesN(aes.basic.UnaryScalarOp): class TimesN(aes.basic.UnaryScalarOp):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论