提交 b12cd96a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Refactor Composite Op

- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values - Use `fgraph_to_python` for `Composite.perform` - Use the `HasInnerGraph` interface
上级 c0d2c635
差异被折叠。
......@@ -2,6 +2,7 @@ import numpy as np
import pytest
import pytensor
import pytensor.tensor as at
import tests.unittest_tools as utt
from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph
......@@ -130,11 +131,16 @@ class TestComposite:
def test_with_constants(self):
x, y, z = floats("xyz")
e = mul(add(70.0, y), true_div(x, y))
C = Composite([x, y], [e])
c = C.make_node(x, y)
assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
g = FunctionGraph([x, y], [c.out])
comp_op = Composite([x, y], [e])
comp_node = comp_op.make_node(x, y)
c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
assert "70.0" in c_code
# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")
g = FunctionGraph([x, y], [comp_node.out])
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0
......@@ -174,24 +180,35 @@ class TestComposite:
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_make_node_continue_graph(self):
# This is a test for a bug (now fixed) that disabled the
# local_gpu_elemwise_0 optimization and printed an
# optimization warning on the terminal.
# We test that Composite.make_node accept as inputs Variable
# some that represent existing computation.
si0 = pytensor.scalar.int8()
si1 = pytensor.scalar.int8()
si2 = pytensor.scalar.float32()
sout = (si0 * si1) / si2
sop = pytensor.scalar.Composite([si0, si1, si2], [sout])
si0 = pytensor.scalar.int8()
si1 = pytensor.scalar.int8()
si2 = pytensor.scalar.float32()
si3 = pytensor.scalar.float32()
sop.make_node(si0 * si3, si1, si2)
def test_non_scalar_error(self):
x = float32("x")
comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
comp_op.fgraph
def test_multi_out_perform(self):
from pytensor.graph.basic import Apply
from pytensor.scalar.basic import ScalarOp
class MultiOutOp(ScalarOp):
def make_node(self, x):
return Apply(self, [x], [x.type(), x.type()])
def perform(self, node, inputs, outputs):
outputs[1][0] = outputs[0][0] = inputs[0]
def c_code(self, *args):
return "dummy"
x = float32("x")
comp_op = Composite([x], MultiOutOp()(x))
y, z = comp_op(x)
fn = pytensor.function([x], [y, z], mode=Mode("py", None))
assert fn(1.0) == [1.0, 1.0]
class TestLogical:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论