提交 b12cd96a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Refactor Composite Op

- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values - Use `fgraph_to_python` for `Composite.perform` - Use the `HasInnerGraph` interface
上级 c0d2c635
差异被折叠。
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
import pytensor.tensor as at
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -130,11 +131,16 @@ class TestComposite: ...@@ -130,11 +131,16 @@ class TestComposite:
def test_with_constants(self): def test_with_constants(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
e = mul(add(70.0, y), true_div(x, y)) e = mul(add(70.0, y), true_div(x, y))
C = Composite([x, y], [e]) comp_op = Composite([x, y], [e])
c = C.make_node(x, y) comp_node = comp_op.make_node(x, y)
assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
g = FunctionGraph([x, y], [c.out]) assert "70.0" in c_code
# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")
g = FunctionGraph([x, y], [comp_node.out])
fn = make_function(DualLinker().accept(g)) fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
...@@ -174,24 +180,35 @@ class TestComposite: ...@@ -174,24 +180,35 @@ class TestComposite:
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
) )
def test_make_node_continue_graph(self): def test_non_scalar_error(self):
# This is a test for a bug (now fixed) that disabled the x = float32("x")
# local_gpu_elemwise_0 optimization and printed an comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
# optimization warning on the terminal.
with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
# We test that Composite.make_node accept as inputs Variable comp_op.fgraph
# some that represent existing computation.
def test_multi_out_perform(self):
si0 = pytensor.scalar.int8() from pytensor.graph.basic import Apply
si1 = pytensor.scalar.int8() from pytensor.scalar.basic import ScalarOp
si2 = pytensor.scalar.float32()
sout = (si0 * si1) / si2 class MultiOutOp(ScalarOp):
sop = pytensor.scalar.Composite([si0, si1, si2], [sout]) def make_node(self, x):
si0 = pytensor.scalar.int8() return Apply(self, [x], [x.type(), x.type()])
si1 = pytensor.scalar.int8()
si2 = pytensor.scalar.float32() def perform(self, node, inputs, outputs):
si3 = pytensor.scalar.float32() outputs[1][0] = outputs[0][0] = inputs[0]
sop.make_node(si0 * si3, si1, si2)
def c_code(self, *args):
return "dummy"
x = float32("x")
comp_op = Composite([x], MultiOutOp()(x))
y, z = comp_op(x)
fn = pytensor.function([x], [y, z], mode=Mode("py", None))
assert fn(1.0) == [1.0, 1.0]
class TestLogical: class TestLogical:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论