提交 b12cd96a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Refactor Composite Op

- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values - Use `fgraph_to_python` for `Composite.perform` - Use the `HasInnerGraph` interface
上级 c0d2c635
...@@ -27,6 +27,7 @@ from pytensor.configdefaults import config ...@@ -27,6 +27,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.type import HasDataType, HasShape
from pytensor.graph.utils import MetaObject, MethodNotDefined from pytensor.graph.utils import MetaObject, MethodNotDefined
...@@ -3987,7 +3988,7 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -3987,7 +3988,7 @@ class ComplexFromPolar(BinaryScalarOp):
complex_from_polar = ComplexFromPolar(name="complex_from_polar") complex_from_polar = ComplexFromPolar(name="complex_from_polar")
class Composite(ScalarOp): class Composite(ScalarOp, HasInnerGraph):
""" """
Composite is an Op that takes a graph of scalar operations and Composite is an Op that takes a graph of scalar operations and
produces c code for the whole graph. Its purpose is to implement loop produces c code for the whole graph. Its purpose is to implement loop
...@@ -3999,9 +4000,65 @@ class Composite(ScalarOp): ...@@ -3999,9 +4000,65 @@ class Composite(ScalarOp):
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs") init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
def __init__(self, inputs, outputs):
    """Build a `Composite` `Op` from a scalar subgraph.

    Parameters
    ----------
    inputs : list of Variable
        The free scalar inputs of the inner graph.
    outputs : list of Variable
        The scalar outputs computed from `inputs`.
    """
    # We need to clone the graph as sometimes its nodes already
    # contain a reference to an fgraph.  As we want the Composite
    # to be pickable, we can't have reference to fgraph.

    # Also, if there is Composite in the inner graph, we want to
    # remove them.  In that case, we do a more complicated clone
    # that will flatten Composite.  We don't need to do this
    # recursively, as the way the fusion optimizer work, we have
    # only 1 new Composite each time at the output.
    for i in inputs:
        assert i not in outputs  # This isn't supported, use identity

    if len(outputs) > 1 or not any(
        isinstance(var.owner.op, Composite) for var in outputs
    ):
        # No inner Composite
        inputs, outputs = clone(inputs, outputs)
    else:
        # Inner Composite that we need to flatten
        assert len(outputs) == 1
        # 1. Create a new graph from inputs up to the
        # Composite
        res = pytensor.compile.rebuild_collect_shared(
            inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
        )  # Clone also the inputs
        # 2. We continue this partial clone with the graph in
        # the inner Composite
        res2 = pytensor.compile.rebuild_collect_shared(
            inputs=outputs[0].owner.op.inputs,
            outputs=outputs[0].owner.op.outputs,
            replace=dict(zip(outputs[0].owner.op.inputs, res[1])),
        )
        assert len(res2[1]) == len(outputs)
        assert len(res[0]) == len(inputs)
        # Sanity check: the clone must have produced fresh input objects.
        assert res[0] != inputs
        inputs, outputs = res[0], res2[1]

    # Shallow-copy so later list mutation by callers can't alias our state.
    self.inputs = copy(inputs)
    self.outputs = copy(outputs)
    self.inputs_type = tuple([input.type for input in inputs])
    self.outputs_type = tuple([output.type for output in outputs])
    self.nin = len(inputs)
    self.nout = len(outputs)
    # Tracks which `impl` values `prepare_node` has already processed.
    self.prepare_node_called = set()
@property
def fn(self):
    # NOTE(review): `_fn` is not assigned anywhere in this view; presumably it
    # is set lazily elsewhere -- confirm, otherwise this raises AttributeError.
    return self._fn
@property
def inner_inputs(self):
    """Inputs of the inner graph (part of the `HasInnerGraph` interface)."""
    return self.fgraph.inputs
@property
def inner_outputs(self):
    """Outputs of the inner graph (part of the `HasInnerGraph` interface)."""
    return self.fgraph.outputs
def __str__(self): def __str__(self):
if self.name is None:
self.init_name()
return self.name return self.name
def make_new_inplace(self, output_types_preference=None, name=None): def make_new_inplace(self, output_types_preference=None, name=None):
...@@ -4020,127 +4077,42 @@ class Composite(ScalarOp): ...@@ -4020,127 +4077,42 @@ class Composite(ScalarOp):
super(Composite, out).__init__(output_types_preference, name) super(Composite, out).__init__(output_types_preference, name)
return out return out
def init_c_code(self): @property
""" def py_perform(self):
Assemble the C code for this Composite Op. if hasattr(self, "_py_perform_fn"):
return self._py_perform_fn
The result is assigned to `self._c_code`.
"""
from pytensor.link.c.interface import CLinkerType
# It was already called
if hasattr(self, "_c_code"):
return
subd = dict(
chain(
((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
)
)
for var in self.fgraph.variables:
if var.owner is None:
if var not in self.fgraph.inputs:
# This is an orphan
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
subd[var] = var.type.c_literal(var.data)
else:
raise ValueError(
"All orphans in the fgraph to Composite must"
" be Constant, CLinkerType instances."
)
elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
o.dtype == "float16" for o in var.owner.outputs
):
# flag for elemwise ops to check.
self.inner_float16 = True
_c_code = "{\n"
self.nodenames = [
f"%(nodename)s_subnode{int(j)}"
for j, n in enumerate(self.fgraph.toposort())
]
i = 0
for j, node in enumerate(self.fgraph.toposort()):
for output in node.outputs:
if output not in subd:
i += 1
name = f"V%(id)s_tmp{int(i)}"
subd[output] = name
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
s = node.op.c_code(
node,
self.nodenames[j],
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
)
_c_code += s
_c_code += "\n"
_c_code += "}\n"
self._c_code = _c_code
def init_py_impls(self):
"""
Return a list of functions that compute each output of self.
"""
# In the case where the graph is a dag, but not a tree like:
# add(*1 -> mul(x, y), *1)
# We have an efficient way to build the executable (we build
# and traverse each node only once).
# But we don't have an efficient execution. We will execute
# like a tree, so nodes that have more then 1 client will be
# executed as many times as there number of clients. In the
# example above, it will calculate *1 twice. Doing otherwise
# imply making a complicated execution engine.
# We need the fast creation of the executor as we always do it
# even if we will use the c code. The Python implementation is
# already slow, so it is not as much important to have a fast
# execution there.
memo = {} from pytensor.link.utils import fgraph_to_python
def compose_impl(r): def python_convert(op, node=None, **kwargs):
if r in memo: assert node is not None
return memo[r]
if r in self.fgraph.inputs:
idx = self.fgraph.inputs.index(r)
def f(inputs): n_outs = len(node.outputs)
return inputs[idx]
memo[r] = f if n_outs > 1:
return f
elif r.owner is None: # in fgraph.orphans:
def f(inputs): def _perform(*inputs, outputs=[[None]] * n_outs):
return r.data op.perform(node, inputs, outputs)
return tuple(o[0] for o in outputs)
memo[r] = f else:
return f
node = r.owner
producers = [compose_impl(input) for input in node.inputs]
def f(inputs): def _perform(*inputs, outputs=[[None]]):
return node.op.impl(*[p(inputs) for p in producers]) op.perform(node, inputs, outputs)
return outputs[0][0]
memo[r] = f return _perform
return f
self._impls = [compose_impl(r) for r in self.fgraph.outputs] self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn
def init_name(self): @property
""" def name(self):
Return a readable string representation of self.fgraph. if hasattr(self, "_name"):
return self._name
""" # TODO FIXME: Just implement pretty printing for the `Op`; don't do
rval = self.name # this redundant, outside work in the `Op` itself.
if rval is None:
for i, r in enumerate(self.fgraph.inputs): for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}" r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs): for i, r in enumerate(self.fgraph.outputs):
...@@ -4151,9 +4123,18 @@ class Composite(ScalarOp): ...@@ -4151,9 +4123,18 @@ class Composite(ScalarOp):
r.name = f"t{int(i)}" r.name = f"t{int(i)}"
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
rval = f"Composite{{{outputs_str}}}" rval = f"Composite{{{outputs_str}}}"
self.name = rval self._name = rval
return self._name
@name.setter
def name(self, name):
self._name = name
@property
def fgraph(self):
if hasattr(self, "_fgraph"):
return self._fgraph
def init_fgraph(self):
# The clone done by FunctionGraph is needed as we don't want # The clone done by FunctionGraph is needed as we don't want
# the fgraph to be set to the variable as we need to pickle # the fgraph to be set to the variable as we need to pickle
# them for the cache of c module to work. # them for the cache of c module to work.
...@@ -4161,64 +4142,14 @@ class Composite(ScalarOp): ...@@ -4161,64 +4142,14 @@ class Composite(ScalarOp):
MergeOptimizer().rewrite(fgraph) MergeOptimizer().rewrite(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError( raise TypeError(
"The fgraph to Composite must be exclusively" "The fgraph to Composite must be exclusively"
" composed of ScalarOp instances." " composed of ScalarOp instances."
) )
self.fgraph = fgraph self._fgraph = fgraph
return self._fgraph
def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
# Also, if there is Composite in the inner graph, we want to
# remove them. In that case, we do a more complicated clone
# that will flatten Composite. We don't need to do this
# recursively, as the way the fusion optimizer work, we have
# only 1 new Composite each time at the output.
for i in inputs:
assert i not in outputs # This isn't supported, use identity
if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs
):
# No inner Composite
inputs, outputs = clone(inputs, outputs)
else:
# Inner Composite that we need to flatten
assert len(outputs) == 1
# 1. Create a new graph from inputs up to the
# Composite
res = pytensor.compile.rebuild_collect_shared(
inputs=inputs, outputs=outputs[0].owner.inputs, copy_inputs_over=False
) # Clone also the inputs
# 2. We continue this partial clone with the graph in
# the inner Composite
res2 = pytensor.compile.rebuild_collect_shared(
inputs=outputs[0].owner.op.inputs,
outputs=outputs[0].owner.op.outputs,
replace=dict(zip(outputs[0].owner.op.inputs, res[1])),
)
assert len(res2[1]) == len(outputs)
assert len(res[0]) == len(inputs)
assert res[0] != inputs
inputs, outputs = res[0], res2[1]
self.inputs = copy(inputs)
self.outputs = copy(outputs)
self.inputs_type = tuple([input.type for input in inputs])
self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs)
self.nout = len(outputs)
self.init_fgraph() # self.fgraph
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self.name = None
self.prepare_node_called = set()
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
self.init_py_impls() # self._impls
if impl not in self.prepare_node_called: if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs): for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl) n.op.prepare_node(n, None, None, impl)
...@@ -4229,7 +4160,13 @@ class Composite(ScalarOp): ...@@ -4229,7 +4160,13 @@ class Composite(ScalarOp):
new_ins, new_outs = composite_f32.apply(self.fgraph) new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs) return Composite(new_ins, new_outs)
def clone(self):
    """Return a new `Composite` built from a `composite_f32`-rewritten copy
    of the inner graph."""
    new_ins, new_outs = composite_f32.apply(self.fgraph)
    return Composite(new_ins, new_outs)
def output_types(self, input_types): def output_types(self, input_types):
# TODO FIXME: What's the intended purpose/use of this method, and why
# does it even need to be a method?
if tuple(input_types) != self.inputs_type: if tuple(input_types) != self.inputs_type:
raise TypeError( raise TypeError(
f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}." f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
...@@ -4256,8 +4193,9 @@ class Composite(ScalarOp): ...@@ -4256,8 +4193,9 @@ class Composite(ScalarOp):
return node return node
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
for storage, impl in zip(output_storage, self._impls): outputs = self.py_perform(*inputs)
storage[0] = impl(inputs) for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val
def impl(self, *inputs): def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)] output_storage = [[None] for i in range(self.nout)]
...@@ -4270,8 +4208,110 @@ class Composite(ScalarOp): ...@@ -4270,8 +4208,110 @@ class Composite(ScalarOp):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite") raise NotImplementedError("grad is not implemented for Composite")
def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
or self.nin != other.nin
or self.nout != other.nout
):
return False
# TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
# object to generate the same `_c_code`?
return self.c_code_template == other.c_code_template
def __hash__(self):
    """Hash on arity plus the generated C code template (mirrors `__eq__`).

    Note that in general, the configparser settings at the time of code
    generation (`__init__`) affect the semantics of this `Op`.  This function
    assumes that all relevant info about the configparser is embodied in
    `_c_code`.  So the `_c_code`, rather than `self.fgraph`, is the signature
    of the semantics of this `Op`.  `_c_code` is preserved through unpickling,
    so the `Op` will not change semantics when it is reloaded with different
    configparser settings.

    TODO FIXME: Doesn't the above just mean that we should be including
    the relevant "configparser settings" here?  Also, why should we even
    care about the exact form of the generated C code when comparing
    `Op`s?  All this smells of leaky concerns and interfaces.
    """
    signature = (type(self), self.nin, self.nout, self.c_code_template)
    return hash(signature)
def __getstate__(self):
    """Return a picklable state dict with all lazily-cached fields removed.

    The dropped entries are rebuilt on demand (or in `__setstate__`), so the
    pickle stays independent of any cached `FunctionGraph`/code objects.
    """
    state = dict(self.__dict__)
    # Transient caches; each is recreated lazily after unpickling.
    for transient_key in ("_c_code", "_py_perform_fn", "_fgraph", "prepare_node_called"):
        state.pop(transient_key, None)
    return state
def __setstate__(self, d):
    """Restore pickled state and re-create the transient fields dropped by
    `__getstate__`."""
    self.__dict__.update(d)
    # `prepare_node_called` was removed during pickling; start fresh.
    self.prepare_node_called = set()
@property
def c_code_template(self):
    """Lazily assemble and cache the C code template for the inner graph.

    The result is a single C block with ``%(...)s`` placeholders for the
    inputs (``i0, i1, ...``), outputs (``o0, o1, ...``), node ids, and the
    failure label; it is cached in ``self._c_code``.

    Side effects: also populates ``self.nodenames`` (one template name per
    toposorted inner node) and may set ``self.inner_float16``.
    """
    from pytensor.link.c.interface import CLinkerType

    # Return the cached template if it was already built.
    if hasattr(self, "_c_code"):
        return self._c_code

    # Map each inner-graph variable to the C name it will have in the
    # generated code: inputs/outputs become format placeholders.
    subd = dict(
        chain(
            ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
            ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
        )
    )

    for var in self.fgraph.variables:
        if var.owner is None:
            if var not in self.fgraph.inputs:
                # This is an orphan
                if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                    # Orphan constants are inlined as C literals.
                    subd[var] = var.type.c_literal(var.data)
                else:
                    raise ValueError(
                        "All orphans in the fgraph to Composite must"
                        " be Constant, CLinkerType instances."
                    )
        elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
            o.dtype == "float16" for o in var.owner.outputs
        ):
            # flag for elemwise ops to check.
            self.inner_float16 = True

    _c_code = "{\n"

    # One placeholder sub-node name per inner node, in toposort order;
    # c_support_code_apply relies on this list.
    self.nodenames = [
        f"%(nodename)s_subnode{int(j)}"
        for j, n in enumerate(self.fgraph.toposort())
    ]

    # Emit, per node: declarations for fresh temporaries, then the node's
    # own C code with inputs/outputs substituted from `subd`.
    i = 0
    for j, node in enumerate(self.fgraph.toposort()):
        for output in node.outputs:
            if output not in subd:
                i += 1
                name = f"V%(id)s_tmp{int(i)}"
                subd[output] = name
                _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
        s = node.op.c_code(
            node,
            self.nodenames[j],
            [subd[input] for input in node.inputs],
            [subd[output] for output in node.outputs],
            dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
        )
        _c_code += s
        _c_code += "\n"

    _c_code += "}\n"

    self._c_code = _c_code

    return self._c_code
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
self.init_c_code()
d = dict( d = dict(
chain( chain(
...@@ -4286,7 +4326,7 @@ class Composite(ScalarOp): ...@@ -4286,7 +4326,7 @@ class Composite(ScalarOp):
# It won't generate conflicting variable name. # It won't generate conflicting variable name.
d["id"] = "_DUMMY_ID_" d["id"] = "_DUMMY_ID_"
return self._c_code % d return self.c_code_template % d
def c_code_cache_version(self): def c_code_cache_version(self):
rval = [3] rval = [3]
...@@ -4314,7 +4354,6 @@ class Composite(ScalarOp): ...@@ -4314,7 +4354,6 @@ class Composite(ScalarOp):
return "\n".join(sorted(rval)) return "\n".join(sorted(rval))
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
self.init_c_code()
rval = [] rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply( subnode_support_code = subnode.op.c_support_code_apply(
...@@ -4328,49 +4367,6 @@ class Composite(ScalarOp): ...@@ -4328,49 +4367,6 @@ class Composite(ScalarOp):
# c_support_code instead of c_support_code_apply. # c_support_code instead of c_support_code_apply.
return "\n".join(rval) return "\n".join(rval)
def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
or self.nin != other.nin
or self.nout != other.nout
):
return False
# see __hash__ for comment on why there is no mention of fgraph
# or module cache key here.
self.init_c_code() # self._c_code and self.nodenames
other.init_c_code()
return self._c_code == other._c_code
def __hash__(self):
self.init_c_code() # self._c_code and self.nodenames
rval = hash((type(self), self.nin, self.nout, self._c_code))
# Note that in general, the configparser settings at the time
# of code generation (__init__) affect the semantics of this Op.
# This function assumes that all relevant info about the configparser
# is embodied in _c_code. So the _c_code, rather than self.fgraph,
# is the signature of the semantics of this Op.
# _c_code is preserved through unpickling, so the Op will not change
# semantics when it is reloaded with different configparser
# settings.
return rval
def __getstate__(self):
rval = dict(self.__dict__)
rval.pop("_impls", None)
rval.pop("prepare_node_called", None)
del rval["fgraph"]
return rval
def __setstate__(self, d):
self.__dict__.update(d)
# We must call init to set fgraph and _impls again, as otherwise
# self.perform will not work.
self.prepare_node_called = set()
self.init_fgraph()
self.init_py_impls()
class Compositef32: class Compositef32:
# This is a dict of scalar op classes that need special handling # This is a dict of scalar op classes that need special handling
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
import pytensor.tensor as at
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -130,11 +131,16 @@ class TestComposite: ...@@ -130,11 +131,16 @@ class TestComposite:
def test_with_constants(self): def test_with_constants(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
e = mul(add(70.0, y), true_div(x, y)) e = mul(add(70.0, y), true_div(x, y))
C = Composite([x, y], [e]) comp_op = Composite([x, y], [e])
c = C.make_node(x, y) comp_node = comp_op.make_node(x, y)
assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
g = FunctionGraph([x, y], [c.out]) assert "70.0" in c_code
# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")
g = FunctionGraph([x, y], [comp_node.out])
fn = make_function(DualLinker().accept(g)) fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
...@@ -174,24 +180,35 @@ class TestComposite: ...@@ -174,24 +180,35 @@ class TestComposite:
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
) )
def test_make_node_continue_graph(self): def test_non_scalar_error(self):
# This is a test for a bug (now fixed) that disabled the x = float32("x")
# local_gpu_elemwise_0 optimization and printed an comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
# optimization warning on the terminal.
with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
# We test that Composite.make_node accept as inputs Variable comp_op.fgraph
# some that represent existing computation.
def test_multi_out_perform(self):
si0 = pytensor.scalar.int8() from pytensor.graph.basic import Apply
si1 = pytensor.scalar.int8() from pytensor.scalar.basic import ScalarOp
si2 = pytensor.scalar.float32()
sout = (si0 * si1) / si2 class MultiOutOp(ScalarOp):
sop = pytensor.scalar.Composite([si0, si1, si2], [sout]) def make_node(self, x):
si0 = pytensor.scalar.int8() return Apply(self, [x], [x.type(), x.type()])
si1 = pytensor.scalar.int8()
si2 = pytensor.scalar.float32() def perform(self, node, inputs, outputs):
si3 = pytensor.scalar.float32() outputs[1][0] = outputs[0][0] = inputs[0]
sop.make_node(si0 * si3, si1, si2)
def c_code(self, *args):
return "dummy"
x = float32("x")
comp_op = Composite([x], MultiOutOp()(x))
y, z = comp_op(x)
fn = pytensor.function([x], [y, z], mode=Mode("py", None))
assert fn(1.0) == [1.0, 1.0]
class TestLogical: class TestLogical:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论