Commit 9ba6d99f authored by Ricardo Vieira, committed by Ricardo Vieira

Replace str "output" by a dummy Op in the clients of the FunctionGraph

Parent 7f623fef
......@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, io_toposort
from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import AlreadyThere, BadOptimization
from pytensor.graph.fg import Output
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.link.basic import Container, LocalLinker
......@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph.
"""
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
return any(
client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
)
def _check_strides_match(a, b, warn_err, op):
......@@ -977,7 +980,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
try:
changed_inner_mode = False
if isinstance(getattr(node, "op", None), HasInnerGraph):
if isinstance(node.op, HasInnerGraph):
fn = node.op.fn
if not (fn and hasattr(fn, "maker") and hasattr(fn.maker, "mode")):
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
......@@ -1132,10 +1135,6 @@ class _FunctionGraphEvent:
def __init__(self, kind, node, idx=None, reason=None):
self.kind = kind
if node == "output":
self.node = "output"
self.op = "output"
else:
self.node = node
self.op = node.op
self.idx = idx
......@@ -1143,7 +1142,7 @@ class _FunctionGraphEvent:
def __str__(self):
if self.kind == "change":
if self.op != "output":
if not isinstance(self.op, Output):
msg = str(len(self.node.inputs))
else:
msg = ""
......
......@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
"""
treeset.add(v)
for cl, v_input_pos_to_cl in fgraph.clients[v]:
if cl == "output":
continue
vmap = cl.op.view_map
dmap = cl.op.destroy_map
for opos, iposlist in chain(vmap.items(), dmap.items()):
......@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)):
original_out = fgraph.outputs[i]
output_client = fgraph.get_output_client(i)
views_of_output_i = set()
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i)
view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
copied = False
# do not allow outputs to be aliased
for j in range(i + 1, len(fgraph.outputs)):
......@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_node_input(
"output", i, view_op(fgraph.outputs[i]), reason=reason
*output_client, view_op(original_out), reason=reason
)
else:
fgraph.change_node_input(
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
*output_client, deep_copy_op(original_out), reason=reason
)
copied = True
break
if not copied:
if not copied: # no-break
for input_j in all_graph_inputs:
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by
......@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
fgraph.change_node_input(
"output",
i,
view_op(fgraph.outputs[i]),
*output_client,
view_op(original_out),
reason=reason,
)
break
else:
fgraph.change_node_input(
"output",
i,
deep_copy_op(fgraph.outputs[i]),
*output_client,
deep_copy_op(original_out),
reason=reason,
)
break
elif wrapped_outputs[i].borrow:
fgraph.change_node_input(
"output",
i,
view_op(fgraph.outputs[i]),
*output_client,
view_op(original_out),
reason=reason,
)
break
else:
fgraph.change_node_input(
"output",
i,
deep_copy_op(fgraph.outputs[i]),
*output_client,
deep_copy_op(original_out),
reason=reason,
)
break
......
......@@ -16,20 +16,17 @@ import sys
import time
from collections import Counter, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import Any
import numpy as np
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from pytensor.graph.fg import FunctionGraph
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
......@@ -1038,7 +1035,7 @@ class ProfileStats:
executable_nodes = set()
for var in fgraph.inputs:
for c, _ in fgraph.clients[var]:
if c != "output":
if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
executable_nodes.add(c)
......@@ -1166,7 +1163,7 @@ class ProfileStats:
for var in node.outputs:
for c, _ in fgraph.clients[var]:
if c != "output":
if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c)
......
......@@ -11,6 +11,7 @@ import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
from pytensor.graph.fg import Output
from pytensor.graph.utils import InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
......@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app.
for app, idx in fgraph.clients[protected_var]:
if app == "output":
continue
destroy_maps = app.op.destroy_map.values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
......@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
app.inputs[i] changed from old_r to new_r.
"""
if app == "output":
if isinstance(app.op, Output):
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
pass
......
Diff collapsed.
......@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
vars_between,
)
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
......@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
if any(
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and c.op.destroy_map
if c.op.destroy_map
):
continue
......@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node):
if real_node == "output":
continue
ret = self.transform(fgraph, real_node, get_nodes=False)
if ret is not False and ret is not None:
return dict(zip(real_node.outputs, ret))
......@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if self.tracks_on_change_inputs:
def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
if node is not current_node and not isinstance(node.op, Output):
q.append(node)
chin = chin_
......
import copy
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional
import pytensor
from pytensor.graph.basic import (
......@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
graph_inputs,
vars_between,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -230,11 +230,9 @@ def get_clients_at_depth(
for var in node.outputs:
if depth > 0:
for out_node, _ in fgraph.clients[var]:
if out_node == "output":
if isinstance(out_node.op, Output):
continue
yield from get_clients_at_depth(
fgraph, cast(Apply, out_node), depth - 1
)
yield from get_clients_at_depth(fgraph, out_node, depth - 1)
else:
assert var.owner is not None
yield var.owner
......@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
# it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`.
if any(
getattr(c.op, "check_input", config.check_input)
for (c, _) in fgraph.clients[r]
if not isinstance(c, str)
getattr(c.op, "check_input", config.check_input) for (c, _) in fgraph.clients[r]
) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)):
c_declare = r.type.c_declare(name, sub, True)
else:
......
......@@ -954,7 +954,6 @@ class VMLinker(LocalLinker):
if k.owner and self.fgraph.clients[k]:
ls = []
for cl in self.fgraph.clients[k]:
if cl[0] != "output":
ls += cl[0].outputs
dependencies[k] += ls
return dependencies
......
......@@ -437,10 +437,15 @@ N.B.:
for out in inner_outputs:
if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph)
or hasattr(getattr(out.owner, "op", None), "scalar_op")
and isinstance(out.owner.op.scalar_op, HasInnerGraph)
) and out not in inner_graph_vars:
out.owner is not None
and (
isinstance(out.owner.op, HasInnerGraph)
or isinstance(
getattr(out.owner.op, "scalar_op", None), HasInnerGraph
)
)
and out not in inner_graph_vars
):
inner_graph_vars.append(out)
_debugprint(
......
......@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
)
from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import (
......@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
for cl, _ in fgraph.clients[out]:
# 2.1 outputs of the function
# => output needs all its intermediate values
if isinstance(cl, str):
if isinstance(cl.op, Output):
# if the node is actually an output, then
# we need to store the entire thing
global_nsteps = None
......@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
for cl, _ in fgraph.clients[out]:
if isinstance(cl, str):
if isinstance(cl.op, Output):
store_steps[i] = 0
break
elif not isinstance(cl.op, Subtensor):
......@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
and isinstance(out.owner.op.scalar_op, ps.Add)
and inp in out.owner.inputs
and len(fgraph.clients[outer_out]) == 1
and not isinstance(fgraph.clients[outer_out][0][0], str)
and not isinstance(fgraph.clients[outer_out][0][0], Output)
and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor)
and fgraph.clients[outer_out][0][0].op.idx_list == (-1,)
):
......
......@@ -24,7 +24,7 @@ from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
......@@ -1681,7 +1681,7 @@ class Alloc(COp):
return False
for client, idx in clients:
if client == "output":
if isinstance(client.op, Output):
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return False
......
......@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when
possible).
"""
def _node_check(n, i):
if n == "output":
n = fgraph.outputs[i].owner
return n == node or isinstance(n.op, Shape | Shape_i)
return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ()))
return any(
n
for n, i in fgraph.clients.get(base_rv, ())
if not (n is node or isinstance(n.op, Shape | Shape_i))
)
@node_rewriter([RandomVariable], inplace=True)
......
......@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import ApplyOrOutput
from pytensor.graph.fg import Output
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
......@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
"""
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[ApplyOrOutput]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
def initialize_fuseable_mappings(
*, fg: FunctionGraph
......@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
for client, _ in clients:
if (
out_maybe_fuseable
and not isinstance(client, str) # "output"
and isinstance(client.op, Elemwise)
# and not isinstance(client.op.scalar_op, ps.Composite)
and len(client.outputs) == 1
......@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
implied_unfuseable_clients = {
c
for client in unfuseable_clients_clone.get(next_out, ())
if not isinstance(client, str) # "output"
if not isinstance(client.op, Output)
for c in client.outputs
}
......
......@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
"""
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0]
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
......
......@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
# this canonized graph... if so, we do nothing and wait for
# them to be transformed.
for c, c_idx in out_clients:
if c == "output":
continue
while (
isinstance(getattr(c, "op", None), DimShuffle)
and len(fgraph.clients[c.outputs[0]]) <= 1
isinstance(c.op, DimShuffle) and len(fgraph.clients[c.outputs[0]]) <= 1
):
c = fgraph.clients[c.outputs[0]][0][0]
if getattr(c, "op", "") in [self.main, self.inverse, self.reciprocal]:
if c.op in [self.main, self.inverse, self.reciprocal]:
return False
# Here we make the canonical version of the graph around this node
......
......@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
merged_shape.append(other_shape[i])
elif (
ps.owner
and isinstance(getattr(ps.owner, "op", None), Shape_i)
and isinstance(ps.owner.op, Shape_i)
and ps.owner.op.i == i
and ps.owner.inputs[0] in (r, other_r)
):
......@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
# r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r
for shpnode, idx in fgraph.clients[r] + [(node, i)]:
if isinstance(getattr(shpnode, "op", None), Shape_i):
if isinstance(shpnode.op, Shape_i):
idx = shpnode.op.i
repl = self.shape_of[new_r][idx]
if repl.owner is shpnode:
......@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
specified_shape = node.inputs[0]
if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
if not (
specified_shape.owner is not None
and isinstance(specified_shape.owner.op, SpecifyShape)
):
return False
x, *shape = specified_shape.owner.inputs
......
......@@ -6,7 +6,7 @@ import pytest
from pytensor.configdefaults import config
from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from tests.graph.utils import (
......@@ -78,8 +78,13 @@ class TestFunctionGraph:
assert fg.variables == {var1, var2, var3, var4}
assert fg.get_clients(var1) == [(var3.owner, 0)]
assert fg.get_clients(var2) == [(var4.owner, 1)]
assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)]
assert fg.get_clients(var4) == [("output", 1)]
var3_clients = fg.get_clients(var3)
assert len(var3_clients) == 2
assert var3_clients[0][0].op == Output(0)
assert var3_clients[1] == (var4.owner, 0)
var4_clients = fg.get_clients(var4)
assert len(var4_clients) == 1
assert var4_clients[0][0].op == Output(1)
varC = MyConstant("varC")
var5 = op1(var1, varC)
......@@ -208,8 +213,11 @@ class TestFunctionGraph:
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var6 = MyVariable2("var6")
[out_client] = [
cl for cl, _ in fg.clients[fg.outputs[0]] if isinstance(cl.op, Output)
]
with pytest.raises(TypeError):
fg.change_node_input("output", 1, var6)
fg.change_node_input(out_client, 0, var6)
with pytest.raises(TypeError):
fg.change_node_input(var5.owner, 1, var6)
......@@ -358,12 +366,13 @@ class TestFunctionGraph:
# TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want.
out_node = Output(idx=1).make_node(var4)
with pytest.raises(Exception, match="Inconsistent clients list.*"):
fg.add_client(var4, ("output", 1))
fg.add_client(var4, (out_node, 0))
fg.check_integrity()
fg.remove_client(var4, ("output", 1))
fg.remove_client(var4, (out_node, 0))
with pytest.raises(TypeError, match="The first entry of.*"):
fg.add_client(var4, (None, 0))
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment