Commit 215cecd4 authored by Ricardo Vieira, committed by Ricardo Vieira

Refactor base class ScalarInnerGraphOp from Composite Op

Parent 774c32ab
@@ -3986,7 +3986,150 @@ class ComplexFromPolar(BinaryScalarOp):
complex_from_polar = ComplexFromPolar(name="complex_from_polar")
-class Composite(ScalarOp, HasInnerGraph):
+class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
"""Includes boilerplate code for Python and C-implementation of Scalar Ops with inner graph."""
def __init__(self, *args, **kwargs):
self.prepare_node_called = set()
@property
def fn(self):
return None
@property
def inner_inputs(self):
return self.fgraph.inputs
@property
def inner_outputs(self):
return self.fgraph.outputs
@property
def py_perform_fn(self):
if hasattr(self, "_py_perform_fn"):
return self._py_perform_fn
from pytensor.link.utils import fgraph_to_python
def python_convert(op, node=None, **kwargs):
assert node is not None
n_outs = len(node.outputs)
if n_outs > 1:
def _perform(*inputs, outputs=[[None]] * n_outs):
op.perform(node, inputs, outputs)
return tuple(o[0] for o in outputs)
else:
def _perform(*inputs, outputs=[[None]]):
op.perform(node, inputs, outputs)
return outputs[0][0]
return _perform
self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn
def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)]
self.perform(None, inputs, output_storage)
ret = to_return_values([storage[0] for storage in output_storage])
if self.nout > 1:
ret = tuple(ret)
return ret
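For context: `impl` and `perform` rely on PyTensor's output-storage convention, in which every output slot is a one-element list that the op writes into. A minimal self-contained sketch of that convention (`toy_perform` is an illustrative stand-in, not a PyTensor API):

# Output-storage convention assumed by `impl`/`perform` above: one
# [None] cell per output; the op writes each result into slot 0.
def toy_perform(node, inputs, output_storage):
    x, y = inputs
    output_storage[0][0] = x + y
    output_storage[1][0] = x * y

storage = [[None] for _ in range(2)]
toy_perform(None, (3.0, 4.0), storage)
assert tuple(s[0] for s in storage) == (7.0, 12.0)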
def c_code_cache_version(self):
rval = list(self.c_code_cache_version_outer())
for x in self.fgraph.toposort():
xv = x.op.c_code_cache_version()
if xv:
rval.append(xv)
else:
return ()
return tuple(rval)
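The version tuple composes the outer implementation's version with every inner op's version, and a single unversioned inner op disables caching for the whole fused kernel. A standalone sketch mirroring (not importing) the method above:

# Fold an outer version together with each inner op's version; any
# empty version poisons the cache key for the whole inner graph.
def combined_cache_version(outer_version, inner_versions):
    rval = list(outer_version)
    for xv in inner_versions:
        if not xv:
            return ()
        rval.append(xv)
    return tuple(rval)

assert combined_cache_version((3,), [(13,), (8,)]) == (3, (13,), (8,))
assert combined_cache_version((3,), [(13,), ()]) == ()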
def c_header_dirs(self, **kwargs):
rval = sum(
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
[],
)
return rval
def c_support_code(self, **kwargs):
# Remove duplicate code blocks by using a `set`
rval = {
subnode.op.c_support_code(**kwargs).strip()
for subnode in self.fgraph.toposort()
}
return "\n".join(sorted(rval))
def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply(
subnode, subnodename % dict(nodename=name)
)
if subnode_support_code:
rval.append(subnode_support_code)
# there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename.
# Any block that isn't specialized should be returned via
# c_support_code instead of c_support_code_apply.
return "\n".join(rval)
def prepare_node(self, node, storage_map, compute_map, impl):
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)
def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
or self.nin != other.nin
or self.nout != other.nout
):
return False
# TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
# object to generate the same `_c_code`?
return self.c_code_template == other.c_code_template
def __hash__(self):
# Note that in general, the configparser settings at the time
# of code generation (__init__) affect the semantics of this Op.
# This function assumes that all relevant info about the configparser
# is embodied in _c_code. So the _c_code, rather than self.fgraph,
# is the signature of the semantics of this Op.
# _c_code is preserved through unpickling, so the Op will not change
# semantics when it is reloaded with different configparser
# settings.
#
# TODO FIXME: Doesn't the above just mean that we should be including
# the relevant "configparser settings" here? Also, why should we even
# care about the exact form of the generated C code when comparing
# `Op`s? All this smells of leaky concerns and interfaces.
return hash((type(self), self.nin, self.nout, self.c_code_template))
def __getstate__(self):
rval = dict(self.__dict__)
rval.pop("_c_code", None)
rval.pop("_py_perform_fn", None)
rval.pop("_fgraph", None)
rval.pop("prepare_node_called", None)
return rval
def __setstate__(self, d):
self.__dict__.update(d)
self.prepare_node_called = set()
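To make the refactor concrete, here is a hedged usage sketch of the `Composite` subclass defined below. It assumes PyTensor is installed and that `Composite` and the scalar types are re-exported from `pytensor.scalar`, and it exercises the pure-Python path inherited from `ScalarInnerGraphOp.impl`:

# Hedged sketch: fuse an inner scalar graph (x + y, x * y) into one Op
# and evaluate it through the inherited Python implementation.
import pytensor.scalar as ps

x = ps.float64("x")
y = ps.float64("y")
comp = ps.Composite([x, y], [x + y, x * y])
print(comp.impl(3.0, 4.0))  # expected: (7.0, 12.0)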
class Composite(ScalarInnerGraphOp):
"""
Composite is an Op that takes a graph of scalar operations and
produces c code for the whole graph. Its purpose is to implement loop
@@ -4043,19 +4186,7 @@ class Composite(ScalarOp, HasInnerGraph):
self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs)
self.nout = len(outputs)
self.prepare_node_called = set()
@property
def fn(self):
return None
@property
def inner_inputs(self):
return self.fgraph.inputs
@property
def inner_outputs(self):
return self.fgraph.outputs
super().__init__()
def __str__(self):
return self.name
@@ -4076,35 +4207,6 @@ class Composite(ScalarOp, HasInnerGraph):
super(Composite, out).__init__(output_types_preference, name)
return out
@property
def py_perform(self):
if hasattr(self, "_py_perform_fn"):
return self._py_perform_fn
from pytensor.link.utils import fgraph_to_python
def python_convert(op, node=None, **kwargs):
assert node is not None
n_outs = len(node.outputs)
if n_outs > 1:
def _perform(*inputs, outputs=[[None]] * n_outs):
op.perform(node, inputs, outputs)
return tuple(o[0] for o in outputs)
else:
def _perform(*inputs, outputs=[[None]]):
op.perform(node, inputs, outputs)
return outputs[0][0]
return _perform
self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn
@property
def fgraph(self):
if hasattr(self, "_fgraph"):
@@ -4139,12 +4241,6 @@ class Composite(ScalarOp, HasInnerGraph):
self._fgraph = fgraph
return self._fgraph
def prepare_node(self, node, storage_map, compute_map, impl):
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)
def clone_float32(self):
# This will not modify the fgraph or the nodes
new_ins, new_outs = composite_f32.apply(self.fgraph)
@@ -4155,8 +4251,6 @@ class Composite(ScalarOp, HasInnerGraph):
return Composite(new_ins, new_outs)
def output_types(self, input_types):
# TODO FIXME: What's the intended purpose/use of this method, and why
# does it even need to be a method?
if tuple(input_types) != self.inputs_type:
raise TypeError(
f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
@@ -4183,63 +4277,13 @@ class Composite(ScalarOp, HasInnerGraph):
return node
def perform(self, node, inputs, output_storage):
-outputs = self.py_perform(*inputs)
+outputs = self.py_perform_fn(*inputs)
for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val
def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)]
self.perform(None, inputs, output_storage)
ret = to_return_values([storage[0] for storage in output_storage])
if self.nout > 1:
ret = tuple(ret)
return ret
def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")
def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
or self.nin != other.nin
or self.nout != other.nout
):
return False
# TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
# object to generate the same `_c_code`?
return self.c_code_template == other.c_code_template
def __hash__(self):
# Note that in general, the configparser settings at the time
# of code generation (__init__) affect the semantics of this Op.
# This function assumes that all relevant info about the configparser
# is embodied in _c_code. So the _c_code, rather than self.fgraph,
# is the signature of the semantics of this Op.
# _c_code is preserved through unpickling, so the Op will not change
# semantics when it is reloaded with different configparser
# settings.
#
# TODO FIXME: Doesn't the above just mean that we should be including
# the relevant "configparser settings" here? Also, why should we even
# care about the exact form of the generated C code when comparing
# `Op`s? All this smells of leaky concerns and interfaces.
return hash((type(self), self.nin, self.nout, self.c_code_template))
def __getstate__(self):
rval = dict(self.__dict__)
rval.pop("_c_code", None)
rval.pop("_py_perform_fn", None)
rval.pop("_fgraph", None)
rval.pop("prepare_node_called", None)
return rval
def __setstate__(self, d):
self.__dict__.update(d)
self.prepare_node_called = set()
@property
def c_code_template(self):
from pytensor.link.c.interface import CLinkerType
@@ -4317,44 +4361,8 @@ class Composite(ScalarOp, HasInnerGraph):
return self.c_code_template % d
def c_code_cache_version(self):
rval = [3]
for x in self.fgraph.toposort():
xv = x.op.c_code_cache_version()
if xv:
rval.append(xv)
else:
return ()
return tuple(rval)
def c_header_dirs(self, **kwargs):
rval = sum(
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
[],
)
return rval
def c_support_code(self, **kwargs):
# Remove duplicate code blocks by using a `set`
rval = {
subnode.op.c_support_code(**kwargs).strip()
for subnode in self.fgraph.toposort()
}
return "\n".join(sorted(rval))
def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply(
subnode, subnodename % dict(nodename=name)
)
if subnode_support_code:
rval.append(subnode_support_code)
# there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename.
# Any block that isn't specialized should be returned via
# c_support_code instead of c_support_code_apply.
return "\n".join(rval)
def c_code_cache_version_outer(self) -> Tuple[int, ...]:
return (3,)
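`c_code_cache_version_outer` is the one hook a subclass still provides; the inherited `c_code_cache_version` folds the inner ops' versions into it. A hypothetical subclass sketch (for illustration only, not part of this commit):

# Hypothetical subclass: declare only the outer version; the base class
# combines it with the versions of every op in the inner graph.
from typing import Tuple

from pytensor.scalar.basic import ScalarInnerGraphOp

class MyFusedScalarOp(ScalarInnerGraphOp):
    def c_code_cache_version_outer(self) -> Tuple[int, ...]:
        return (1,)  # bump when this op's own C template changes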
class Compositef32: