Commit b12cd96a authored by Brandon T. Willard, committed by Ricardo Vieira

Refactor Composite Op

- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values
- Use `fgraph_to_python` for `Composite.perform`
- Use the `HasInnerGraph` interface

Parent commit: c0d2c635
......@@ -27,6 +27,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.graph.type import HasDataType, HasShape
from pytensor.graph.utils import MetaObject, MethodNotDefined
......@@ -3987,7 +3988,7 @@ class ComplexFromPolar(BinaryScalarOp):
complex_from_polar = ComplexFromPolar(name="complex_from_polar")
class Composite(ScalarOp):
class Composite(ScalarOp, HasInnerGraph):
"""
Composite is an Op that takes a graph of scalar operations and
produces c code for the whole graph. Its purpose is to implement loop
......@@ -3999,174 +4000,6 @@ class Composite(ScalarOp):
init_param: Union[Tuple[str, str], Tuple[str]] = ("inputs", "outputs")
def __str__(self):
    """Return a readable representation, generating the name lazily."""
    name = self.name
    if name is None:
        self.init_name()
        name = self.name
    return name
def make_new_inplace(self, output_types_preference=None, name=None):
    """Build a copy of this `Op` suitable for in-place rewrites.

    ``Composite.__init__`` does not take the same parameters as other
    scalar `Op`s, which breaks the ``insert_inplace_optimizer``
    rewrite; this constructor-like helper works around that.
    """
    kwargs = {param: getattr(self, param) for param in self.init_param}
    new_op = self.__class__(**kwargs)
    if name:
        new_op.name = name
    else:
        name = new_op.name
    super(Composite, new_op).__init__(output_types_preference, name)
    return new_op
def init_c_code(self):
    """
    Assemble the C code for this Composite Op.

    The result is assigned to `self._c_code`.
    """
    from pytensor.link.c.interface import CLinkerType

    # It was already called
    if hasattr(self, "_c_code"):
        return

    # Map each inner input/output variable to a %-style placeholder that
    # will be substituted later when the final C code is rendered.
    subd = dict(
        chain(
            ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
            ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
        )
    )

    for var in self.fgraph.variables:
        if var.owner is None:
            if var not in self.fgraph.inputs:
                # This is an orphan
                if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                    # Inline the constant directly as a C literal.
                    subd[var] = var.type.c_literal(var.data)
                else:
                    raise ValueError(
                        "All orphans in the fgraph to Composite must"
                        " be Constant, CLinkerType instances."
                    )
        elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
            o.dtype == "float16" for o in var.owner.outputs
        ):
            # flag for elemwise ops to check.
            self.inner_float16 = True

    _c_code = "{\n"
    # One unique C sub-node name per inner Apply node, in topological order.
    self.nodenames = [
        f"%(nodename)s_subnode{int(j)}"
        for j, n in enumerate(self.fgraph.toposort())
    ]

    i = 0
    for j, node in enumerate(self.fgraph.toposort()):
        for output in node.outputs:
            if output not in subd:
                # Intermediate value: declare a temporary C variable for it.
                i += 1
                name = f"V%(id)s_tmp{int(i)}"
                subd[output] = name
                _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
        # Delegate code generation for this node to its own Op.
        s = node.op.c_code(
            node,
            self.nodenames[j],
            [subd[input] for input in node.inputs],
            [subd[output] for output in node.outputs],
            dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
        )
        _c_code += s
        _c_code += "\n"
    _c_code += "}\n"

    self._c_code = _c_code
def init_py_impls(self):
    """
    Build a list of functions, one per output of ``self``, and assign
    it to `self._impls`.
    """
    # In the case where the graph is a dag, but not a tree like:
    # add(*1 -> mul(x, y), *1)
    # We have an efficient way to build the executable (we build
    # and traverse each node only once).
    # But we don't have an efficient execution. We will execute
    # like a tree, so nodes that have more then 1 client will be
    # executed as many times as there number of clients. In the
    # example above, it will calculate *1 twice. Doing otherwise
    # imply making a complicated execution engine.
    # We need the fast creation of the executor as we always do it
    # even if we will use the c code. The Python implementation is
    # already slow, so it is not as much important to have a fast
    # execution there.
    memo = {}

    def compose_impl(r):
        # Memoized recursion: each variable's thunk is built exactly once.
        if r in memo:
            return memo[r]
        if r in self.fgraph.inputs:
            # Graph input: simply select the corresponding runtime input.
            idx = self.fgraph.inputs.index(r)

            def f(inputs):
                return inputs[idx]

            memo[r] = f
            return f
        elif r.owner is None:  # in fgraph.orphans:
            # Constant/orphan: its value is captured in the closure.

            def f(inputs):
                return r.data

            memo[r] = f
            return f
        # Interior node: evaluate its producers, then apply its Op.
        node = r.owner
        producers = [compose_impl(input) for input in node.inputs]

        def f(inputs):
            return node.op.impl(*[p(inputs) for p in producers])

        memo[r] = f
        return f

    self._impls = [compose_impl(r) for r in self.fgraph.outputs]
def init_name(self):
    """Compute and cache a readable string representation of `self.fgraph`."""
    if self.name is not None:
        return
    # Give stable names to the inner graph's variables so the pretty
    # printer produces a deterministic string.
    for idx, inp in enumerate(self.fgraph.inputs):
        inp.name = f"i{int(idx)}"
    for idx, out in enumerate(self.fgraph.outputs):
        out.name = f"o{int(idx)}"
    boundary = set(self.fgraph.inputs + self.fgraph.outputs)
    for idx, var in enumerate(self.fgraph.variables):
        # Only name intermediates that are used more than once.
        if var not in boundary and len(self.fgraph.clients[var]) > 1:
            var.name = f"t{int(idx)}"
    printed = ", ".join(pprint(out) for out in self.fgraph.outputs)
    self.name = f"Composite{{{printed}}}"
def init_fgraph(self):
    """Construct, merge-rewrite, and validate the inner `FunctionGraph`.

    The result is assigned to `self.fgraph`.

    Raises
    ------
    TypeError
        If any node of the inner graph is not a `ScalarOp`.
    """
    # The clone done by FunctionGraph is needed as we don't want
    # the fgraph to be set to the variable as we need to pickle
    # them for the cache of c module to work.
    fgraph = FunctionGraph(self.inputs, self.outputs)
    MergeOptimizer().rewrite(fgraph)
    for node in fgraph.apply_nodes:
        if not isinstance(node.op, ScalarOp):
            # A non-scalar inner Op is a *type* violation: raise TypeError,
            # consistent with the lazy `fgraph` property version and with
            # `test_non_scalar_error`, which expects TypeError.
            raise TypeError(
                "The fgraph to Composite must be exclusively"
                " composed of ScalarOp instances."
            )
    self.fgraph = fgraph
def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
......@@ -4179,6 +4012,7 @@ class Composite(ScalarOp):
# only 1 new Composite each time at the output.
for i in inputs:
assert i not in outputs # This isn't supported, use identity
if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs
):
......@@ -4210,15 +4044,112 @@ class Composite(ScalarOp):
self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs)
self.nout = len(outputs)
self.init_fgraph() # self.fgraph
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self.name = None
self.prepare_node_called = set()
@property
def fn(self):
    # Part of the `HasInnerGraph` interface.
    # NOTE(review): `_fn` is not assigned anywhere in this view — confirm
    # where (or whether) it gets set before this property is accessed.
    return self._fn
@property
def inner_inputs(self):
    """Inputs of the inner `FunctionGraph` (`HasInnerGraph` interface)."""
    return self.fgraph.inputs
@property
def inner_outputs(self):
    """Outputs of the inner `FunctionGraph` (`HasInnerGraph` interface)."""
    return self.fgraph.outputs
def __str__(self):
    """Return the (lazily computed) `name` of this `Op`."""
    return self.name
def make_new_inplace(self, output_types_preference=None, name=None):
    """Construct a copy of this `Op` for use by in-place rewrites.

    ``Composite.__init__`` takes different parameters than other scalar
    `Op`s, which breaks the ``insert_inplace_optimizer`` rewrite; this
    helper patches around that by rebuilding from `init_param`.
    """
    init_kwargs = {key: getattr(self, key) for key in self.init_param}
    clone = self.__class__(**init_kwargs)
    if name:
        clone.name = name
    else:
        name = clone.name
    super(Composite, clone).__init__(output_types_preference, name)
    return clone
@property
def py_perform(self):
    """Lazily build and cache a Python implementation of the inner graph.

    Returns a function mapping the scalar input values to the output
    value (single output) or a tuple of output values (multiple outputs),
    produced via `fgraph_to_python`.
    """
    if hasattr(self, "_py_perform_fn"):
        return self._py_perform_fn

    from pytensor.link.utils import fgraph_to_python

    def python_convert(op, node=None, **kwargs):
        assert node is not None

        n_outs = len(node.outputs)

        if n_outs > 1:

            def _perform(*inputs):
                # Allocate fresh, independent storage cells on every call.
                # The previous implementation used a mutable default
                # argument ``outputs=[[None]] * n_outs``, which (a) shared
                # the storage across calls and (b) aliased every output
                # cell to the *same* inner list, so ops producing distinct
                # output values clobbered each other.
                outputs = [[None] for _ in range(n_outs)]
                op.perform(node, inputs, outputs)
                return tuple(o[0] for o in outputs)

        else:

            def _perform(*inputs):
                outputs = [[None]]
                op.perform(node, inputs, outputs)
                return outputs[0][0]

        return _perform

    self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
    return self._py_perform_fn
@property
def name(self):
    """A readable string representation of the inner graph (lazily built)."""
    if hasattr(self, "_name"):
        return self._name

    # TODO FIXME: Just implement pretty printing for the `Op`; don't do
    # this redundant, outside work in the `Op` itself.

    # NOTE: this renames the inner-graph variables *in place* as a side
    # effect, so that the pretty-printed output is deterministic.
    for i, r in enumerate(self.fgraph.inputs):
        r.name = f"i{int(i)}"
    for i, r in enumerate(self.fgraph.outputs):
        r.name = f"o{int(i)}"
    io = set(self.fgraph.inputs + self.fgraph.outputs)
    for i, r in enumerate(self.fgraph.variables):
        # Only name intermediates that are used more than once.
        if r not in io and len(self.fgraph.clients[r]) > 1:
            r.name = f"t{int(i)}"
    outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
    rval = f"Composite{{{outputs_str}}}"
    self._name = rval
    return self._name
@name.setter
def name(self, name):
    # Setting the name explicitly bypasses the lazy computation above.
    self._name = name
@property
def fgraph(self):
    """The inner `FunctionGraph`, built lazily and cached."""
    if hasattr(self, "_fgraph"):
        return self._fgraph
    # FunctionGraph clones the graph: we must not attach an fgraph to
    # the original variables, because those variables are pickled for
    # the C-module cache to work.
    inner = FunctionGraph(self.inputs, self.outputs)
    MergeOptimizer().rewrite(inner)
    offending = [
        n for n in inner.apply_nodes if not isinstance(n.op, ScalarOp)
    ]
    if offending:
        raise TypeError(
            "The fgraph to Composite must be exclusively"
            " composed of ScalarOp instances."
        )
    self._fgraph = inner
    return self._fgraph
def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
self.init_py_impls() # self._impls
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
......@@ -4229,7 +4160,13 @@ class Composite(ScalarOp):
new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs)
def clone(self):
    """Return a new `Composite` built from a clone of the inner graph."""
    # NOTE(review): this reuses `composite_f32.apply`, which may also
    # adjust float16/float32 handling while cloning — confirm that is
    # intended for a generic `clone`.
    new_ins, new_outs = composite_f32.apply(self.fgraph)
    return Composite(new_ins, new_outs)
def output_types(self, input_types):
# TODO FIXME: What's the intended purpose/use of this method, and why
# does it even need to be a method?
if tuple(input_types) != self.inputs_type:
raise TypeError(
f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
......@@ -4256,8 +4193,9 @@ class Composite(ScalarOp):
return node
def perform(self, node, inputs, output_storage):
    """Evaluate the inner graph in Python and fill `output_storage`."""
    results = self.py_perform(*inputs)
    for cell, value in zip(output_storage, results):
        cell[0] = value
def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)]
......@@ -4270,8 +4208,110 @@ class Composite(ScalarOp):
def grad(self, inputs, output_grads):
    # Gradients are expected to be taken through the inner graph's
    # individual scalar ops, not through the fused `Composite` itself.
    raise NotImplementedError("grad is not implemented for Composite")
def __eq__(self, other):
    """Two `Composite`s are equal iff they generate the same C code."""
    if self is other:
        return True
    comparable = (
        type(self) == type(other)
        and self.nin == other.nin
        and self.nout == other.nout
    )
    if not comparable:
        return False
    # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
    # object to generate the same `_c_code`?
    return self.c_code_template == other.c_code_template
def __hash__(self):
    # The configparser settings active at code-generation time affect the
    # semantics of this Op, and all of that information is assumed to be
    # embodied in the generated C code; the C code — not `self.fgraph` —
    # therefore serves as the signature of the Op's semantics.  Since the
    # C code survives unpickling, the Op keeps its semantics when it is
    # reloaded under different configparser settings.
    #
    # TODO FIXME: Doesn't the above just mean that we should be including
    # the relevant "configparser settings" here? Also, why should we even
    # care about the exact form of the generated C code when comparing
    # `Op`s? All this smells of leaky concerns and interfaces.
    key = (type(self), self.nin, self.nout, self.c_code_template)
    return hash(key)
def __getstate__(self):
    """Drop lazily derived/cached attributes before pickling."""
    state = self.__dict__.copy()
    for cached in ("_c_code", "_py_perform_fn", "_fgraph", "prepare_node_called"):
        state.pop(cached, None)
    return state
def __setstate__(self, d):
    self.__dict__.update(d)
    # `prepare_node_called` was dropped by `__getstate__`; restore an
    # empty set so `prepare_node` works again after unpickling.
    self.prepare_node_called = set()
@property
def c_code_template(self):
    """Lazily assemble and cache the C code template for the inner graph.

    The template is cached on `self._c_code`; `self.nodenames` is set as
    a side effect.
    """
    from pytensor.link.c.interface import CLinkerType

    if hasattr(self, "_c_code"):
        return self._c_code

    # Map each inner input/output variable to a %-style placeholder that
    # `c_code` substitutes when rendering the final C code.
    subd = dict(
        chain(
            ((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
            ((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
        )
    )

    for var in self.fgraph.variables:
        if var.owner is None:
            if var not in self.fgraph.inputs:
                # This is an orphan
                if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                    # Inline the constant directly as a C literal.
                    subd[var] = var.type.c_literal(var.data)
                else:
                    raise ValueError(
                        "All orphans in the fgraph to Composite must"
                        " be Constant, CLinkerType instances."
                    )
        elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
            o.dtype == "float16" for o in var.owner.outputs
        ):
            # flag for elemwise ops to check.
            self.inner_float16 = True

    _c_code = "{\n"
    # One unique C sub-node name per inner Apply node, in topological order.
    self.nodenames = [
        f"%(nodename)s_subnode{int(j)}"
        for j, n in enumerate(self.fgraph.toposort())
    ]

    i = 0
    for j, node in enumerate(self.fgraph.toposort()):
        for output in node.outputs:
            if output not in subd:
                # Intermediate value: declare a temporary C variable for it.
                i += 1
                name = f"V%(id)s_tmp{int(i)}"
                subd[output] = name
                _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
        # Delegate code generation for this node to its own Op.
        s = node.op.c_code(
            node,
            self.nodenames[j],
            [subd[input] for input in node.inputs],
            [subd[output] for output in node.outputs],
            dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
        )
        _c_code += s
        _c_code += "\n"
    _c_code += "}\n"

    self._c_code = _c_code

    return self._c_code
def c_code(self, node, nodename, inames, onames, sub):
self.init_c_code()
d = dict(
chain(
......@@ -4286,7 +4326,7 @@ class Composite(ScalarOp):
# It won't generate conflicting variable name.
d["id"] = "_DUMMY_ID_"
return self._c_code % d
return self.c_code_template % d
def c_code_cache_version(self):
rval = [3]
......@@ -4314,7 +4354,6 @@ class Composite(ScalarOp):
return "\n".join(sorted(rval))
def c_support_code_apply(self, node, name):
self.init_c_code()
rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply(
......@@ -4328,49 +4367,6 @@ class Composite(ScalarOp):
# c_support_code instead of c_support_code_apply.
return "\n".join(rval)
def __eq__(self, other):
    """Compare `Composite`s by their generated C code (see `__hash__`)."""
    if self is other:
        return True
    comparable = (
        type(self) == type(other)
        and self.nin == other.nin
        and self.nout == other.nout
    )
    if not comparable:
        return False
    # see __hash__ for comment on why there is no mention of fgraph
    # or module cache key here.
    self.init_c_code()  # ensures self._c_code and self.nodenames
    other.init_c_code()
    return self._c_code == other._c_code
def __hash__(self):
    """Hash based on the generated C code rather than the inner graph."""
    self.init_c_code()  # self._c_code and self.nodenames
    rval = hash((type(self), self.nin, self.nout, self._c_code))
    # Note that in general, the configparser settings at the time
    # of code generation (__init__) affect the semantics of this Op.
    # This function assumes that all relevant info about the configparser
    # is embodied in _c_code. So the _c_code, rather than self.fgraph,
    # is the signature of the semantics of this Op.
    # _c_code is preserved through unpickling, so the Op will not change
    # semantics when it is reloaded with different configparser
    # settings.
    return rval
def __getstate__(self):
    """Strip recomputable attributes from the pickled state."""
    state = self.__dict__.copy()
    state.pop("_impls", None)
    state.pop("prepare_node_called", None)
    # `fgraph` is always present and is rebuilt by `__setstate__`.
    state.pop("fgraph")
    return state
def __setstate__(self, d):
    self.__dict__.update(d)
    # We must call init to set fgraph and _impls again, as otherwise
    # self.perform will not work.
    self.prepare_node_called = set()
    self.init_fgraph()
    self.init_py_impls()
class Compositef32:
# This is a dict of scalar op classes that need special handling
......
......@@ -2,6 +2,7 @@ import numpy as np
import pytest
import pytensor
import pytensor.tensor as at
import tests.unittest_tools as utt
from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph
......@@ -130,11 +131,16 @@ class TestComposite:
def test_with_constants(self):
    """Constants in the inner graph must appear as literals in the C code."""
    x, y, z = floats("xyz")
    expr = mul(add(70.0, y), true_div(x, y))
    op = Composite([x, y], [expr])
    node = op.make_node(x, y)

    generated = node.op.c_code(node, "dummy", ["x", "y"], ["z"], dict(id=0))
    assert "70.0" in generated

    # Make sure caching of the c_code template works
    assert hasattr(node.op, "_c_code")

    fg = FunctionGraph([x, y], [node.out])
    fn = make_function(DualLinker().accept(fg))
    assert fn(1.0, 2.0) == 36.0
......@@ -174,24 +180,35 @@ class TestComposite:
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_make_node_continue_graph(self):
    # This is a test for a bug (now fixed) that disabled the
    # local_gpu_elemwise_0 optimization and printed an
    # optimization warning on the terminal.
    # Check that `Composite.make_node` accepts `Variable` inputs that
    # already represent existing computation.
    a = pytensor.scalar.int8()
    b = pytensor.scalar.int8()
    c = pytensor.scalar.float32()
    result = (a * b) / c
    op = pytensor.scalar.Composite([a, b, c], [result])

    new_a = pytensor.scalar.int8()
    new_b = pytensor.scalar.int8()
    new_c = pytensor.scalar.float32()
    new_d = pytensor.scalar.float32()

    op.make_node(new_a * new_d, new_b, new_c)
def test_non_scalar_error(self):
    # Building the inner `fgraph` must fail with `TypeError` when the
    # graph contains non-scalar (tensor) operations.
    x = float32("x")
    comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
    with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
        comp_op.fgraph
def test_multi_out_perform(self):
    # A `Composite` wrapping a multi-output scalar `Op` must route each
    # output to its own storage cell in Python (`perform`) mode.
    from pytensor.graph.basic import Apply
    from pytensor.scalar.basic import ScalarOp

    class MultiOutOp(ScalarOp):
        # Minimal two-output scalar Op for exercising `perform`.
        def make_node(self, x):
            return Apply(self, [x], [x.type(), x.type()])

        def perform(self, node, inputs, outputs):
            # Both outputs receive the (single) input value.
            outputs[1][0] = outputs[0][0] = inputs[0]

        def c_code(self, *args):
            # Never compiled in this test; only the Python path is used.
            return "dummy"

    x = float32("x")
    comp_op = Composite([x], MultiOutOp()(x))

    y, z = comp_op(x)

    fn = pytensor.function([x], [y, z], mode=Mode("py", None))

    assert fn(1.0) == [1.0, 1.0]
class TestLogical:
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment