Commit b12cd96a authored by Brandon T. Willard, committed by Ricardo Vieira

Refactor Composite Op

- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values - Use `fgraph_to_python` for `Composite.perform` - Use the `HasInnerGraph` interface
Parent c0d2c635
...@@ -27,6 +27,7 @@ from pytensor.configdefaults import config ...@@ -27,6 +27,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.type import HasDataType, HasShape
from pytensor.graph.utils import MetaObject, MethodNotDefined from pytensor.graph.utils import MetaObject, MethodNotDefined
...@@ -3987,7 +3988,7 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -3987,7 +3988,7 @@ class ComplexFromPolar(BinaryScalarOp):
complex_from_polar = ComplexFromPolar(name="complex_from_polar") complex_from_polar = ComplexFromPolar(name="complex_from_polar")
class Composite(ScalarOp): class Composite(ScalarOp, HasInnerGraph):
""" """
Composite is an Op that takes a graph of scalar operations and Composite is an Op that takes a graph of scalar operations and
produces c code for the whole graph. Its purpose is to implement loop produces c code for the whole graph. Its purpose is to implement loop
...@@ -3999,174 +4000,6 @@ class Composite(ScalarOp): ...@@ -3999,174 +4000,6 @@ class Composite(ScalarOp):
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs") init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
def __str__(self):
    """Return a readable representation, building the name on demand."""
    name = self.name
    if name is None:
        # `init_name` computes the pretty-printed form and caches it
        # on `self.name`.
        self.init_name()
        name = self.name
    return name
def make_new_inplace(self, output_types_preference=None, name=None):
    """
    Create a fresh copy of this `Op` suitable for in-place rewrites.

    `Composite.__init__` does not take the same parameters as other
    scalar `Op`s, which breaks the `insert_inplace_optimizer`
    optimization; this method works around that by rebuilding the `Op`
    from `self.init_param` and re-running the base-class initialization.
    """
    init_kwargs = {param: getattr(self, param) for param in self.init_param}
    new_op = self.__class__(**init_kwargs)
    if name:
        new_op.name = name
    else:
        name = new_op.name
    # Re-run the `ScalarOp` initialization with the requested preference.
    super(Composite, new_op).__init__(output_types_preference, name)
    return new_op
def init_c_code(self):
    """
    Assemble the C code for this Composite Op.

    The result is assigned to `self._c_code` (and the per-node C names to
    `self.nodenames`); subsequent calls are no-ops.
    """
    from pytensor.link.c.interface import CLinkerType

    # It was already called
    if hasattr(self, "_c_code"):
        return

    # Map each fgraph input/output to a `%`-style placeholder that the
    # caller (`c_code`) will later substitute with real C variable names.
    subd = dict(
        chain(
            ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
            ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
        )
    )

    for var in self.fgraph.variables:
        if var.owner is None:
            if var not in self.fgraph.inputs:
                # This is an orphan
                if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                    # Inline the constant directly as a C literal.
                    subd[var] = var.type.c_literal(var.data)
                else:
                    raise ValueError(
                        "All orphans in the fgraph to Composite must"
                        " be Constant, CLinkerType instances."
                    )
        elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
            o.dtype == "float16" for o in var.owner.outputs
        ):
            # flag for elemwise ops to check.
            self.inner_float16 = True

    _c_code = "{\n"
    # One unique C sub-node name per inner `Apply` node, in toposort order.
    self.nodenames = [
        f"%(nodename)s_subnode{int(j)}"
        for j, n in enumerate(self.fgraph.toposort())
    ]

    i = 0
    for j, node in enumerate(self.fgraph.toposort()):
        for output in node.outputs:
            if output not in subd:
                # Intermediate value: declare a temporary C variable for it.
                i += 1
                name = f"V%(id)s_tmp{int(i)}"
                subd[output] = name
                _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
        s = node.op.c_code(
            node,
            self.nodenames[j],
            [subd[input] for input in node.inputs],
            [subd[output] for output in node.outputs],
            dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
        )
        _c_code += s
        _c_code += "\n"
    _c_code += "}\n"

    self._c_code = _c_code
def init_py_impls(self):
    """
    Build a list of functions that compute each output of self.

    The list is cached on ``self._impls``; each function maps the flat
    sequence of input values to one output value.
    """
    # In the case where the graph is a dag, but not a tree like:
    #   add(*1 -> mul(x, y), *1)
    # We have an efficient way to build the executable (we build
    # and traverse each node only once).
    # But we don't have an efficient execution. We will execute
    # like a tree, so nodes that have more than one client will be
    # executed as many times as their number of clients. In the
    # example above, it will calculate *1 twice. Doing otherwise
    # would imply making a complicated execution engine.
    # We need the fast creation of the executor as we always do it
    # even if we will use the c code. The Python implementation is
    # already slow, so it is not as important to have a fast
    # execution there.
    memo = {}

    def compose_impl(r):
        # Memoize per variable so a shared sub-graph is only *built* once
        # (though it may still be *executed* more than once; see above).
        if r in memo:
            return memo[r]
        if r in self.fgraph.inputs:
            idx = self.fgraph.inputs.index(r)

            def f(inputs):
                return inputs[idx]

            memo[r] = f
            return f
        elif r.owner is None:  # in fgraph.orphans:
            # Constant/orphan: its value is fixed at build time.
            def f(inputs):
                return r.data

            memo[r] = f
            return f
        node = r.owner
        producers = [compose_impl(input) for input in node.inputs]

        def f(inputs):
            return node.op.impl(*[p(inputs) for p in producers])

        memo[r] = f
        return f

    self._impls = [compose_impl(r) for r in self.fgraph.outputs]
def init_name(self):
    """
    Compute and cache a readable string representation of ``self.fgraph``.
    """
    if self.name is not None:
        return
    fgraph = self.fgraph
    # Give the interface variables canonical names (i*/o*) and shared
    # temporaries t* so the pretty-printed graph is readable.
    for idx, var in enumerate(fgraph.inputs):
        var.name = f"i{int(idx)}"
    for idx, var in enumerate(fgraph.outputs):
        var.name = f"o{int(idx)}"
    boundary = set(fgraph.inputs + fgraph.outputs)
    for idx, var in enumerate(fgraph.variables):
        if var not in boundary and len(fgraph.clients[var]) > 1:
            var.name = f"t{int(idx)}"
    printed_outputs = ", ".join(pprint(out) for out in fgraph.outputs)
    self.name = f"Composite{{{printed_outputs}}}"
def init_fgraph(self):
    """Build, rewrite, validate, and cache the inner `FunctionGraph`."""
    # The clone done by `FunctionGraph` is needed: we don't want the
    # fgraph to be attached to the original variables, since we need to
    # pickle them for the C-module cache to work.
    inner_fgraph = FunctionGraph(self.inputs, self.outputs)
    MergeOptimizer().rewrite(inner_fgraph)
    if not all(isinstance(n.op, ScalarOp) for n in inner_fgraph.apply_nodes):
        raise ValueError(
            "The fgraph to Composite must be exclusively"
            " composed of ScalarOp instances."
        )
    self.fgraph = inner_fgraph
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite # contain a reference to an fgraph. As we want the Composite
...@@ -4179,6 +4012,7 @@ class Composite(ScalarOp): ...@@ -4179,6 +4012,7 @@ class Composite(ScalarOp):
# only 1 new Composite each time at the output. # only 1 new Composite each time at the output.
for i in inputs: for i in inputs:
assert i not in outputs # This isn't supported, use identity assert i not in outputs # This isn't supported, use identity
if len(outputs) > 1 or not any( if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs isinstance(var.owner.op, Composite) for var in outputs
): ):
...@@ -4210,15 +4044,112 @@ class Composite(ScalarOp): ...@@ -4210,15 +4044,112 @@ class Composite(ScalarOp):
self.outputs_type = tuple([output.type for output in outputs]) self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs) self.nin = len(inputs)
self.nout = len(outputs) self.nout = len(outputs)
self.init_fgraph() # self.fgraph
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self.name = None
self.prepare_node_called = set() self.prepare_node_called = set()
@property
def fn(self):
    # Part of the `HasInnerGraph` interface.
    # NOTE(review): `_fn` is never assigned anywhere in this class, so
    # accessing this property before something external sets `_fn` will
    # raise `AttributeError` — confirm intended.
    return self._fn
@property
def inner_inputs(self):
    """Inputs of the inner graph (`HasInnerGraph` interface)."""
    return self.fgraph.inputs
@property
def inner_outputs(self):
    """Outputs of the inner graph (`HasInnerGraph` interface)."""
    return self.fgraph.outputs
def __str__(self):
    # `name` is a lazily computed property, so this triggers (and caches)
    # the pretty-printed representation on first use.
    return self.name
def make_new_inplace(self, output_types_preference=None, name=None):
    """
    Create a copy of this `Op` with the given output-type preference.

    `Composite.__init__` does not take the same parameters as other
    scalar `Op`s, which breaks the `insert_inplace_optimizer`
    optimization; this method patches around that by rebuilding the `Op`
    from `self.init_param` and re-running the base-class initialization.
    """
    d = {k: getattr(self, k) for k in self.init_param}
    out = self.__class__(**d)
    if name:
        out.name = name
    else:
        name = out.name
    super(Composite, out).__init__(output_types_preference, name)
    return out
@property
def py_perform(self):
    """
    A lazily created and cached Python implementation of the inner graph.

    Returns a callable taking the scalar input values and returning the
    output value(s), built via `fgraph_to_python` over `self.fgraph`.
    """
    if hasattr(self, "_py_perform_fn"):
        return self._py_perform_fn

    from pytensor.link.utils import fgraph_to_python

    def python_convert(op, node=None, **kwargs):
        assert node is not None

        n_outs = len(node.outputs)

        if n_outs > 1:

            def _perform(*inputs):
                # Allocate fresh, distinct storage cells on every call.
                # The previous `outputs=[[None]] * n_outs` default was
                # buggy: `[[None]] * n_outs` repeats ONE inner list, so
                # every output aliased the last value written, and the
                # default was also shared across calls.
                outputs = [[None] for _ in range(n_outs)]
                op.perform(node, inputs, outputs)
                return tuple(o[0] for o in outputs)

        else:

            def _perform(*inputs):
                outputs = [[None]]
                op.perform(node, inputs, outputs)
                return outputs[0][0]

        return _perform

    self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
    return self._py_perform_fn
@property
def name(self):
    """A pretty-printed name for this `Op`, computed lazily and cached."""
    if hasattr(self, "_name"):
        return self._name

    # TODO FIXME: Just implement pretty printing for the `Op`; don't do
    # this redundant, outside work in the `Op` itself.

    # Give the interface variables canonical names (i*/o*) and shared
    # temporaries t* so the printed graph is readable.  Note that this
    # mutates the inner graph's variable names as a side effect.
    for i, r in enumerate(self.fgraph.inputs):
        r.name = f"i{int(i)}"
    for i, r in enumerate(self.fgraph.outputs):
        r.name = f"o{int(i)}"
    io = set(self.fgraph.inputs + self.fgraph.outputs)
    for i, r in enumerate(self.fgraph.variables):
        if r not in io and len(self.fgraph.clients[r]) > 1:
            r.name = f"t{int(i)}"

    outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
    rval = f"Composite{{{outputs_str}}}"
    self._name = rval

    return self._name
@name.setter
def name(self, name):
    # Allow callers to override the lazily computed name explicitly.
    self._name = name
@property
def fgraph(self):
    """The inner `FunctionGraph`, built lazily and cached on `_fgraph`.

    Raises
    ------
    TypeError
        If the rewritten inner graph contains a non-`ScalarOp` node.
    """
    if hasattr(self, "_fgraph"):
        return self._fgraph
    # The clone done by FunctionGraph is needed as we don't want
    # the fgraph to be set to the variable as we need to pickle
    # them for the cache of c module to work.
    fgraph = FunctionGraph(self.inputs, self.outputs)
    MergeOptimizer().rewrite(fgraph)
    for node in fgraph.apply_nodes:
        if not isinstance(node.op, ScalarOp):
            raise TypeError(
                "The fgraph to Composite must be exclusively"
                " composed of ScalarOp instances."
            )
    self._fgraph = fgraph
    return self._fgraph
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
self.init_py_impls() # self._impls
if impl not in self.prepare_node_called: if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs): for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl) n.op.prepare_node(n, None, None, impl)
...@@ -4229,7 +4160,13 @@ class Composite(ScalarOp): ...@@ -4229,7 +4160,13 @@ class Composite(ScalarOp):
new_ins, new_outs = composite_f32.apply(self.fgraph) new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs) return Composite(new_ins, new_outs)
def clone(self):
    """Return a new `Composite` built from this one's inner graph."""
    # NOTE(review): this applies `composite_f32` (apparently a
    # float16->float32 conversion pass) while cloning, rather than doing
    # a plain copy — confirm that is intended for `clone`.
    new_ins, new_outs = composite_f32.apply(self.fgraph)
    return Composite(new_ins, new_outs)
def output_types(self, input_types): def output_types(self, input_types):
# TODO FIXME: What's the intended purpose/use of this method, and why
# does it even need to be a method?
if tuple(input_types) != self.inputs_type: if tuple(input_types) != self.inputs_type:
raise TypeError( raise TypeError(
f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}." f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
...@@ -4256,8 +4193,9 @@ class Composite(ScalarOp): ...@@ -4256,8 +4193,9 @@ class Composite(ScalarOp):
return node return node
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
for storage, impl in zip(output_storage, self._impls): outputs = self.py_perform(*inputs)
storage[0] = impl(inputs) for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val
def impl(self, *inputs): def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)] output_storage = [[None] for i in range(self.nout)]
...@@ -4270,8 +4208,110 @@ class Composite(ScalarOp): ...@@ -4270,8 +4208,110 @@ class Composite(ScalarOp):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite") raise NotImplementedError("grad is not implemented for Composite")
def __eq__(self, other):
    """Equality based on arity and the generated C code template."""
    if self is other:
        return True
    same_arity = (
        type(self) == type(other)
        and self.nin == other.nin
        and self.nout == other.nout
    )
    if not same_arity:
        return False
    # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
    # object to generate the same `_c_code`?
    return self.c_code_template == other.c_code_template
def __hash__(self):
    """Hash on arity plus the C code template (see comments below)."""
    # Note that in general, the configparser settings at the time
    # of code generation (__init__) affect the semantics of this Op.
    # This function assumes that all relevant info about the configparser
    # is embodied in _c_code. So the _c_code, rather than self.fgraph,
    # is the signature of the semantics of this Op.
    # _c_code is preserved through unpickling, so the Op will not change
    # semantics when it is reloaded with different configparser
    # settings.
    #
    # TODO FIXME: Doesn't the above just mean that we should be including
    # the relevant "configparser settings" here? Also, why should we even
    # care about the exact form of the generated C code when comparing
    # `Op`s? All this smells of leaky concerns and interfaces.
    return hash((type(self), self.nin, self.nout, self.c_code_template))
def __getstate__(self):
    # Drop all lazily computed, re-derivable state so pickles stay small
    # and regeneration happens under the unpickling configuration.
    rval = dict(self.__dict__)
    rval.pop("_c_code", None)
    rval.pop("_py_perform_fn", None)
    rval.pop("_fgraph", None)
    rval.pop("prepare_node_called", None)
    # NOTE(review): `nodenames` (assigned alongside `_c_code` in
    # `c_code_template`) is kept in the pickle — confirm it stays
    # consistent with the C code regenerated after unpickling.
    return rval
def __setstate__(self, d):
    self.__dict__.update(d)
    # `prepare_node_called` was dropped in `__getstate__`; recreate it so
    # `prepare_node` works after unpickling.
    self.prepare_node_called = set()
@property
def c_code_template(self):
    """The C code for the whole inner graph, built lazily and cached.

    Also assigns `self.nodenames` (one C sub-node name per inner
    `Apply` node) as a side effect of the first access.
    """
    from pytensor.link.c.interface import CLinkerType

    if hasattr(self, "_c_code"):
        return self._c_code

    # Map each fgraph input/output to a `%`-style placeholder that
    # `c_code` will later substitute with real C variable names.
    subd = dict(
        chain(
            ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
            ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
        )
    )

    for var in self.fgraph.variables:
        if var.owner is None:
            if var not in self.fgraph.inputs:
                # This is an orphan
                if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                    # Inline the constant directly as a C literal.
                    subd[var] = var.type.c_literal(var.data)
                else:
                    raise ValueError(
                        "All orphans in the fgraph to Composite must"
                        " be Constant, CLinkerType instances."
                    )
        elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
            o.dtype == "float16" for o in var.owner.outputs
        ):
            # flag for elemwise ops to check.
            self.inner_float16 = True

    _c_code = "{\n"
    # One unique C sub-node name per inner `Apply` node, in toposort order.
    self.nodenames = [
        f"%(nodename)s_subnode{int(j)}"
        for j, n in enumerate(self.fgraph.toposort())
    ]

    i = 0
    for j, node in enumerate(self.fgraph.toposort()):
        for output in node.outputs:
            if output not in subd:
                # Intermediate value: declare a temporary C variable for it.
                i += 1
                name = f"V%(id)s_tmp{int(i)}"
                subd[output] = name
                _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
        s = node.op.c_code(
            node,
            self.nodenames[j],
            [subd[input] for input in node.inputs],
            [subd[output] for output in node.outputs],
            dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
        )
        _c_code += s
        _c_code += "\n"
    _c_code += "}\n"

    self._c_code = _c_code

    return self._c_code
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
self.init_c_code()
d = dict( d = dict(
chain( chain(
...@@ -4286,7 +4326,7 @@ class Composite(ScalarOp): ...@@ -4286,7 +4326,7 @@ class Composite(ScalarOp):
# It won't generate conflicting variable name. # It won't generate conflicting variable name.
d["id"] = "_DUMMY_ID_" d["id"] = "_DUMMY_ID_"
return self._c_code % d return self.c_code_template % d
def c_code_cache_version(self): def c_code_cache_version(self):
rval = [3] rval = [3]
...@@ -4314,7 +4354,6 @@ class Composite(ScalarOp): ...@@ -4314,7 +4354,6 @@ class Composite(ScalarOp):
return "\n".join(sorted(rval)) return "\n".join(sorted(rval))
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
self.init_c_code()
rval = [] rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply( subnode_support_code = subnode.op.c_support_code_apply(
...@@ -4328,49 +4367,6 @@ class Composite(ScalarOp): ...@@ -4328,49 +4367,6 @@ class Composite(ScalarOp):
# c_support_code instead of c_support_code_apply. # c_support_code instead of c_support_code_apply.
return "\n".join(rval) return "\n".join(rval)
def __eq__(self, other):
    """Equality based on arity and the generated C code."""
    if self is other:
        return True
    if (
        type(self) != type(other)
        or self.nin != other.nin
        or self.nout != other.nout
    ):
        return False
    # see __hash__ for comment on why there is no mention of fgraph
    # or module cache key here.
    self.init_c_code()  # self._c_code and self.nodenames
    other.init_c_code()
    return self._c_code == other._c_code
def __hash__(self):
    """Hash on arity plus the generated C code (see comments below)."""
    self.init_c_code()  # self._c_code and self.nodenames
    rval = hash((type(self), self.nin, self.nout, self._c_code))
    # Note that in general, the configparser settings at the time
    # of code generation (__init__) affect the semantics of this Op.
    # This function assumes that all relevant info about the configparser
    # is embodied in _c_code. So the _c_code, rather than self.fgraph,
    # is the signature of the semantics of this Op.
    # _c_code is preserved through unpickling, so the Op will not change
    # semantics when it is reloaded with different configparser
    # settings.
    return rval
def __getstate__(self):
    # Drop derived/unpicklable state; it is rebuilt in `__setstate__`.
    rval = dict(self.__dict__)
    rval.pop("_impls", None)
    rval.pop("prepare_node_called", None)
    del rval["fgraph"]
    return rval
def __setstate__(self, d):
    self.__dict__.update(d)
    # We must call init to set fgraph and _impls again, as otherwise
    # self.perform will not work.
    self.prepare_node_called = set()
    self.init_fgraph()
    self.init_py_impls()
class Compositef32: class Compositef32:
# This is a dict of scalar op classes that need special handling # This is a dict of scalar op classes that need special handling
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
import pytensor.tensor as at
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -130,11 +131,16 @@ class TestComposite: ...@@ -130,11 +131,16 @@ class TestComposite:
def test_with_constants(self): def test_with_constants(self):
x, y, z = floats("xyz") x, y, z = floats("xyz")
e = mul(add(70.0, y), true_div(x, y)) e = mul(add(70.0, y), true_div(x, y))
C = Composite([x, y], [e]) comp_op = Composite([x, y], [e])
c = C.make_node(x, y) comp_node = comp_op.make_node(x, y)
assert "70.0" in c.op.c_code(c, "dummy", ["x", "y"], ["z"], dict(id=0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
g = FunctionGraph([x, y], [c.out]) assert "70.0" in c_code
# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")
g = FunctionGraph([x, y], [comp_node.out])
fn = make_function(DualLinker().accept(g)) fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
...@@ -174,24 +180,35 @@ class TestComposite: ...@@ -174,24 +180,35 @@ class TestComposite:
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" "*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
) )
def test_make_node_continue_graph(self): def test_non_scalar_error(self):
# This is a test for a bug (now fixed) that disabled the x = float32("x")
# local_gpu_elemwise_0 optimization and printed an comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
# optimization warning on the terminal.
with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
# We test that Composite.make_node accept as inputs Variable comp_op.fgraph
# some that represent existing computation.
def test_multi_out_perform(self):
si0 = pytensor.scalar.int8() from pytensor.graph.basic import Apply
si1 = pytensor.scalar.int8() from pytensor.scalar.basic import ScalarOp
si2 = pytensor.scalar.float32()
sout = (si0 * si1) / si2 class MultiOutOp(ScalarOp):
sop = pytensor.scalar.Composite([si0, si1, si2], [sout]) def make_node(self, x):
si0 = pytensor.scalar.int8() return Apply(self, [x], [x.type(), x.type()])
si1 = pytensor.scalar.int8()
si2 = pytensor.scalar.float32() def perform(self, node, inputs, outputs):
si3 = pytensor.scalar.float32() outputs[1][0] = outputs[0][0] = inputs[0]
sop.make_node(si0 * si3, si1, si2)
def c_code(self, *args):
return "dummy"
x = float32("x")
comp_op = Composite([x], MultiOutOp()(x))
y, z = comp_op(x)
fn = pytensor.function([x], [y, z], mode=Mode("py", None))
assert fn(1.0) == [1.0, 1.0]
class TestLogical: class TestLogical:
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment