提交 9ba6d99f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Replace str "output" by a dummy Op in the clients of the FunctionGraph

上级 7f623fef
...@@ -30,6 +30,7 @@ from pytensor.configdefaults import config ...@@ -30,6 +30,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, io_toposort from pytensor.graph.basic import Variable, io_toposort
from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import AlreadyThere, BadOptimization from pytensor.graph.features import AlreadyThere, BadOptimization
from pytensor.graph.fg import Output
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.link.basic import Container, LocalLinker from pytensor.link.basic import Container, LocalLinker
...@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var): ...@@ -628,7 +629,9 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph. True if `var` is used by another node in the graph.
""" """
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == []) return any(
client for client, _ in fgraph.clients[var] if not isinstance(client.op, Output)
)
def _check_strides_match(a, b, warn_err, op): def _check_strides_match(a, b, warn_err, op):
...@@ -977,7 +980,7 @@ def _check_preallocated_output( ...@@ -977,7 +980,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run. # disable memory checks in that mode, since they were already run.
try: try:
changed_inner_mode = False changed_inner_mode = False
if isinstance(getattr(node, "op", None), HasInnerGraph): if isinstance(node.op, HasInnerGraph):
fn = node.op.fn fn = node.op.fn
if not (fn and hasattr(fn, "maker") and hasattr(fn.maker, "mode")): if not (fn and hasattr(fn, "maker") and hasattr(fn.maker, "mode")):
_logger.warning(f"Expected pytensor function not found in {node.op}.fn") _logger.warning(f"Expected pytensor function not found in {node.op}.fn")
...@@ -1132,10 +1135,6 @@ class _FunctionGraphEvent: ...@@ -1132,10 +1135,6 @@ class _FunctionGraphEvent:
def __init__(self, kind, node, idx=None, reason=None): def __init__(self, kind, node, idx=None, reason=None):
self.kind = kind self.kind = kind
if node == "output":
self.node = "output"
self.op = "output"
else:
self.node = node self.node = node
self.op = node.op self.op = node.op
self.idx = idx self.idx = idx
...@@ -1143,7 +1142,7 @@ class _FunctionGraphEvent: ...@@ -1143,7 +1142,7 @@ class _FunctionGraphEvent:
def __str__(self): def __str__(self):
if self.kind == "change": if self.kind == "change":
if self.op != "output": if not isinstance(self.op, Output):
msg = str(len(self.node.inputs)) msg = str(len(self.node.inputs))
else: else:
msg = "" msg = ""
......
...@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset): ...@@ -78,8 +78,6 @@ def view_tree_set(fgraph, v, treeset):
""" """
treeset.add(v) treeset.add(v)
for cl, v_input_pos_to_cl in fgraph.clients[v]: for cl, v_input_pos_to_cl in fgraph.clients[v]:
if cl == "output":
continue
vmap = cl.op.view_map vmap = cl.op.view_map
dmap = cl.op.destroy_map dmap = cl.op.destroy_map
for opos, iposlist in chain(vmap.items(), dmap.items()): for opos, iposlist in chain(vmap.items(), dmap.items()):
...@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1202,8 +1200,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
has_destroyers_attr = hasattr(fgraph, "has_destroyers") has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)): for i in range(len(fgraph.outputs)):
original_out = fgraph.outputs[i]
output_client = fgraph.get_output_client(i)
views_of_output_i = set() views_of_output_i = set()
view_tree_set(fgraph, alias_root(fgraph.outputs[i]), views_of_output_i) view_tree_set(fgraph, alias_root(original_out), views_of_output_i)
copied = False copied = False
# do not allow outputs to be aliased # do not allow outputs to be aliased
for j in range(i + 1, len(fgraph.outputs)): for j in range(i + 1, len(fgraph.outputs)):
...@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1212,16 +1213,16 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if fgraph.outputs[j] in views_of_output_i: if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow: if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_node_input( fgraph.change_node_input(
"output", i, view_op(fgraph.outputs[i]), reason=reason *output_client, view_op(original_out), reason=reason
) )
else: else:
fgraph.change_node_input( fgraph.change_node_input(
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason *output_client, deep_copy_op(original_out), reason=reason
) )
copied = True copied = True
break break
if not copied: if not copied: # no-break
for input_j in all_graph_inputs: for input_j in all_graph_inputs:
# do not allow outputs to be aliased to an inputs (j), unless # do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by # a) that j'th input has been 'destroyed' by
...@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1239,33 +1240,29 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
j = fgraph.inputs.index(input_j) j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow: if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
fgraph.change_node_input( fgraph.change_node_input(
"output", *output_client,
i, view_op(original_out),
view_op(fgraph.outputs[i]),
reason=reason, reason=reason,
) )
break break
else: else:
fgraph.change_node_input( fgraph.change_node_input(
"output", *output_client,
i, deep_copy_op(original_out),
deep_copy_op(fgraph.outputs[i]),
reason=reason, reason=reason,
) )
break break
elif wrapped_outputs[i].borrow: elif wrapped_outputs[i].borrow:
fgraph.change_node_input( fgraph.change_node_input(
"output", *output_client,
i, view_op(original_out),
view_op(fgraph.outputs[i]),
reason=reason, reason=reason,
) )
break break
else: else:
fgraph.change_node_input( fgraph.change_node_input(
"output", *output_client,
i, deep_copy_op(original_out),
deep_copy_op(fgraph.outputs[i]),
reason=reason, reason=reason,
) )
break break
......
...@@ -16,20 +16,17 @@ import sys ...@@ -16,20 +16,17 @@ import sys
import time import time
from collections import Counter, defaultdict from collections import Counter, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any from typing import Any
import numpy as np import numpy as np
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.link.utils import get_destroy_dependencies from pytensor.link.utils import get_destroy_dependencies
if TYPE_CHECKING:
from pytensor.graph.fg import FunctionGraph
@contextmanager @contextmanager
def extended_open(filename, mode="r"): def extended_open(filename, mode="r"):
if filename == "<stdout>": if filename == "<stdout>":
...@@ -1038,7 +1035,7 @@ class ProfileStats: ...@@ -1038,7 +1035,7 @@ class ProfileStats:
executable_nodes = set() executable_nodes = set()
for var in fgraph.inputs: for var in fgraph.inputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c] deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
executable_nodes.add(c) executable_nodes.add(c)
...@@ -1166,7 +1163,7 @@ class ProfileStats: ...@@ -1166,7 +1163,7 @@ class ProfileStats:
for var in node.outputs: for var in node.outputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c] deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c) new_exec_nodes.add(c)
......
...@@ -11,6 +11,7 @@ import pytensor ...@@ -11,6 +11,7 @@ import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper from pytensor.graph.features import AlreadyThere, Bookkeeper
from pytensor.graph.fg import Output
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
...@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper): ...@@ -401,8 +402,6 @@ class DestroyHandler(Bookkeeper):
def recursive_destroys_finder(protected_var): def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app. # protected_var is the idx'th input of app.
for app, idx in fgraph.clients[protected_var]: for app, idx in fgraph.clients[protected_var]:
if app == "output":
continue
destroy_maps = app.op.destroy_map.values() destroy_maps = app.op.destroy_map.values()
# If True means that the apply node, destroys the protected_var. # If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]: if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
...@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper): ...@@ -578,7 +577,7 @@ class DestroyHandler(Bookkeeper):
app.inputs[i] changed from old_r to new_r. app.inputs[i] changed from old_r to new_r.
""" """
if app == "output": if isinstance(app.op, Output):
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
......
差异被折叠。
...@@ -30,7 +30,7 @@ from pytensor.graph.basic import ( ...@@ -30,7 +30,7 @@ from pytensor.graph.basic import (
vars_between, vars_between,
) )
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
...@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter): ...@@ -738,7 +738,7 @@ class MergeOptimizer(GraphRewriter):
if any( if any(
i in flatten(c.op.destroy_map.values()) i in flatten(c.op.destroy_map.values())
for c, i in clients for c, i in clients
if c != "output" and c.op.destroy_map if c.op.destroy_map
): ):
continue continue
...@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1612,8 +1612,6 @@ class PatternNodeRewriter(NodeRewriter):
if get_nodes and self.get_nodes is not None: if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node): for real_node in self.get_nodes(fgraph, node):
if real_node == "output":
continue
ret = self.transform(fgraph, real_node, get_nodes=False) ret = self.transform(fgraph, real_node, get_nodes=False)
if ret is not False and ret is not None: if ret is not False and ret is not None:
return dict(zip(real_node.outputs, ret)) return dict(zip(real_node.outputs, ret))
...@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2399,7 +2397,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if self.tracks_on_change_inputs: if self.tracks_on_change_inputs:
def chin_(node, i, r, new_r, reason): def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str): if node is not current_node and not isinstance(node.op, Output):
q.append(node) q.append(node)
chin = chin_ chin = chin_
......
import copy import copy
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional
import pytensor import pytensor
from pytensor.graph.basic import ( from pytensor.graph.basic import (
...@@ -10,7 +10,7 @@ from pytensor.graph.basic import ( ...@@ -10,7 +10,7 @@ from pytensor.graph.basic import (
graph_inputs, graph_inputs,
vars_between, vars_between,
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -230,11 +230,9 @@ def get_clients_at_depth( ...@@ -230,11 +230,9 @@ def get_clients_at_depth(
for var in node.outputs: for var in node.outputs:
if depth > 0: if depth > 0:
for out_node, _ in fgraph.clients[var]: for out_node, _ in fgraph.clients[var]:
if out_node == "output": if isinstance(out_node.op, Output):
continue continue
yield from get_clients_at_depth( yield from get_clients_at_depth(fgraph, out_node, depth - 1)
fgraph, cast(Apply, out_node), depth - 1
)
else: else:
assert var.owner is not None assert var.owner is not None
yield var.owner yield var.owner
...@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub): ...@@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub):
# it means they need `r`'s dtype to be declared, so # it means they need `r`'s dtype to be declared, so
# we have to pass `check_input=True` to `c_declare`. # we have to pass `check_input=True` to `c_declare`.
if any( if any(
getattr(c.op, "check_input", config.check_input) getattr(c.op, "check_input", config.check_input) for (c, _) in fgraph.clients[r]
for (c, _) in fgraph.clients[r]
if not isinstance(c, str)
) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)): ) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)):
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
else: else:
......
...@@ -954,7 +954,6 @@ class VMLinker(LocalLinker): ...@@ -954,7 +954,6 @@ class VMLinker(LocalLinker):
if k.owner and self.fgraph.clients[k]: if k.owner and self.fgraph.clients[k]:
ls = [] ls = []
for cl in self.fgraph.clients[k]: for cl in self.fgraph.clients[k]:
if cl[0] != "output":
ls += cl[0].outputs ls += cl[0].outputs
dependencies[k] += ls dependencies[k] += ls
return dependencies return dependencies
......
...@@ -437,10 +437,15 @@ N.B.: ...@@ -437,10 +437,15 @@ N.B.:
for out in inner_outputs: for out in inner_outputs:
if ( if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph) out.owner is not None
or hasattr(getattr(out.owner, "op", None), "scalar_op") and (
and isinstance(out.owner.op.scalar_op, HasInnerGraph) isinstance(out.owner.op, HasInnerGraph)
) and out not in inner_graph_vars: or isinstance(
getattr(out.owner.op, "scalar_op", None), HasInnerGraph
)
)
and out not in inner_graph_vars
):
inner_graph_vars.append(out) inner_graph_vars.append(out)
_debugprint( _debugprint(
......
...@@ -27,7 +27,7 @@ from pytensor.graph.basic import ( ...@@ -27,7 +27,7 @@ from pytensor.graph.basic import (
) )
from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
...@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node): ...@@ -1303,7 +1303,7 @@ def scan_save_mem(fgraph, node):
for cl, _ in fgraph.clients[out]: for cl, _ in fgraph.clients[out]:
# 2.1 outputs of the function # 2.1 outputs of the function
# => output needs all its intermediate values # => output needs all its intermediate values
if isinstance(cl, str): if isinstance(cl.op, Output):
# if the node is actually an output, then # if the node is actually an output, then
# we need to store the entire thing # we need to store the entire thing
global_nsteps = None global_nsteps = None
...@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node): ...@@ -1412,7 +1412,7 @@ def scan_save_mem(fgraph, node):
for i, out in enumerate(node.outputs[:c_outs]): for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients # look at all its clients
for cl, _ in fgraph.clients[out]: for cl, _ in fgraph.clients[out]:
if isinstance(cl, str): if isinstance(cl.op, Output):
store_steps[i] = 0 store_steps[i] = 0
break break
elif not isinstance(cl.op, Subtensor): elif not isinstance(cl.op, Subtensor):
...@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2309,7 +2309,7 @@ def scan_push_out_dot1(fgraph, node):
and isinstance(out.owner.op.scalar_op, ps.Add) and isinstance(out.owner.op.scalar_op, ps.Add)
and inp in out.owner.inputs and inp in out.owner.inputs
and len(fgraph.clients[outer_out]) == 1 and len(fgraph.clients[outer_out]) == 1
and not isinstance(fgraph.clients[outer_out][0][0], str) and not isinstance(fgraph.clients[outer_out][0][0], Output)
and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor) and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor)
and fgraph.clients[outer_out][0][0].op.idx_list == (-1,) and fgraph.clients[outer_out][0][0].op.idx_list == (-1,)
): ):
......
...@@ -24,7 +24,7 @@ from pytensor import scalar as ps ...@@ -24,7 +24,7 @@ from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
...@@ -1681,7 +1681,7 @@ class Alloc(COp): ...@@ -1681,7 +1681,7 @@ class Alloc(COp):
return False return False
for client, idx in clients: for client, idx in clients:
if client == "output": if isinstance(client.op, Output):
# If the output is a constant, it will have to be deepcopied # If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold. # each time the function is called. So we do not fold.
return False return False
......
...@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph): ...@@ -31,15 +31,12 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
TODO: We should apply all the shape rewrites before these rewrites, since TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when that would properly remove the unnecessary dependencies on `base_rv` (when
possible). possible).
""" """
return any(
def _node_check(n, i): n
if n == "output": for n, i in fgraph.clients.get(base_rv, ())
n = fgraph.outputs[i].owner if not (n is node or isinstance(n.op, Shape | Shape_i))
return n == node or isinstance(n.op, Shape | Shape_i) )
return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ()))
@node_rewriter([RandomVariable], inplace=True) @node_rewriter([RandomVariable], inplace=True)
......
...@@ -14,7 +14,7 @@ from pytensor.configdefaults import config ...@@ -14,7 +14,7 @@ from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import ApplyOrOutput from pytensor.graph.fg import Output
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter, EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
...@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -688,7 +688,7 @@ class FusionOptimizer(GraphRewriter):
""" """
FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]]
UNFUSEABLE_MAPPING = defaultdict[Variable, set[ApplyOrOutput]] UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]]
def initialize_fuseable_mappings( def initialize_fuseable_mappings(
*, fg: FunctionGraph *, fg: FunctionGraph
...@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter): ...@@ -727,7 +727,6 @@ class FusionOptimizer(GraphRewriter):
for client, _ in clients: for client, _ in clients:
if ( if (
out_maybe_fuseable out_maybe_fuseable
and not isinstance(client, str) # "output"
and isinstance(client.op, Elemwise) and isinstance(client.op, Elemwise)
# and not isinstance(client.op.scalar_op, ps.Composite) # and not isinstance(client.op.scalar_op, ps.Composite)
and len(client.outputs) == 1 and len(client.outputs) == 1
...@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter): ...@@ -841,7 +840,7 @@ class FusionOptimizer(GraphRewriter):
implied_unfuseable_clients = { implied_unfuseable_clients = {
c c
for client in unfuseable_clients_clone.get(next_out, ()) for client in unfuseable_clients_clone.get(next_out, ())
if not isinstance(client, str) # "output" if not isinstance(client.op, Output)
for c in client.outputs for c in client.outputs
} }
......
...@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node): ...@@ -299,8 +299,6 @@ def local_det_chol(fgraph, node):
""" """
(x,) = node.inputs (x,) = node.inputs
for cl, xpos in fgraph.clients[x]: for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0] L = cl.outputs[0]
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
......
...@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1126,14 +1126,11 @@ class AlgebraicCanonizer(NodeRewriter):
# this canonized graph... if so, we do nothing and wait for # this canonized graph... if so, we do nothing and wait for
# them to be transformed. # them to be transformed.
for c, c_idx in out_clients: for c, c_idx in out_clients:
if c == "output":
continue
while ( while (
isinstance(getattr(c, "op", None), DimShuffle) isinstance(c.op, DimShuffle) and len(fgraph.clients[c.outputs[0]]) <= 1
and len(fgraph.clients[c.outputs[0]]) <= 1
): ):
c = fgraph.clients[c.outputs[0]][0][0] c = fgraph.clients[c.outputs[0]][0][0]
if getattr(c, "op", "") in [self.main, self.inverse, self.reciprocal]: if c.op in [self.main, self.inverse, self.reciprocal]:
return False return False
# Here we make the canonical version of the graph around this node # Here we make the canonical version of the graph around this node
......
...@@ -401,7 +401,7 @@ class ShapeFeature(Feature): ...@@ -401,7 +401,7 @@ class ShapeFeature(Feature):
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
elif ( elif (
ps.owner ps.owner
and isinstance(getattr(ps.owner, "op", None), Shape_i) and isinstance(ps.owner.op, Shape_i)
and ps.owner.op.i == i and ps.owner.op.i == i
and ps.owner.inputs[0] in (r, other_r) and ps.owner.inputs[0] in (r, other_r)
): ):
...@@ -602,7 +602,7 @@ class ShapeFeature(Feature): ...@@ -602,7 +602,7 @@ class ShapeFeature(Feature):
# r is *scheduled*. # r is *scheduled*.
# At that point, node is no longer a client of r, but of new_r # At that point, node is no longer a client of r, but of new_r
for shpnode, idx in fgraph.clients[r] + [(node, i)]: for shpnode, idx in fgraph.clients[r] + [(node, i)]:
if isinstance(getattr(shpnode, "op", None), Shape_i): if isinstance(shpnode.op, Shape_i):
idx = shpnode.op.i idx = shpnode.op.i
repl = self.shape_of[new_r][idx] repl = self.shape_of[new_r][idx]
if repl.owner is shpnode: if repl.owner is shpnode:
...@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -1057,7 +1057,10 @@ def local_Shape_of_SpecifyShape(fgraph, node):
specified_shape = node.inputs[0] specified_shape = node.inputs[0]
if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): if not (
specified_shape.owner is not None
and isinstance(specified_shape.owner.op, SpecifyShape)
):
return False return False
x, *shape = specified_shape.owner.inputs x, *shape = specified_shape.owner.inputs
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import NominalVariable from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint from pytensor.printing import debugprint
from tests.graph.utils import ( from tests.graph.utils import (
...@@ -78,8 +78,13 @@ class TestFunctionGraph: ...@@ -78,8 +78,13 @@ class TestFunctionGraph:
assert fg.variables == {var1, var2, var3, var4} assert fg.variables == {var1, var2, var3, var4}
assert fg.get_clients(var1) == [(var3.owner, 0)] assert fg.get_clients(var1) == [(var3.owner, 0)]
assert fg.get_clients(var2) == [(var4.owner, 1)] assert fg.get_clients(var2) == [(var4.owner, 1)]
assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)] var3_clients = fg.get_clients(var3)
assert fg.get_clients(var4) == [("output", 1)] assert len(var3_clients) == 2
assert var3_clients[0][0].op == Output(0)
assert var3_clients[1] == (var4.owner, 0)
var4_clients = fg.get_clients(var4)
assert len(var4_clients) == 1
assert var4_clients[0][0].op == Output(1)
varC = MyConstant("varC") varC = MyConstant("varC")
var5 = op1(var1, varC) var5 = op1(var1, varC)
...@@ -208,8 +213,11 @@ class TestFunctionGraph: ...@@ -208,8 +213,11 @@ class TestFunctionGraph:
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var6 = MyVariable2("var6") var6 = MyVariable2("var6")
[out_client] = [
cl for cl, _ in fg.clients[fg.outputs[0]] if isinstance(cl.op, Output)
]
with pytest.raises(TypeError): with pytest.raises(TypeError):
fg.change_node_input("output", 1, var6) fg.change_node_input(out_client, 0, var6)
with pytest.raises(TypeError): with pytest.raises(TypeError):
fg.change_node_input(var5.owner, 1, var6) fg.change_node_input(var5.owner, 1, var6)
...@@ -358,12 +366,13 @@ class TestFunctionGraph: ...@@ -358,12 +366,13 @@ class TestFunctionGraph:
# TODO: What if the index value is greater than 1? It will throw an # TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want. # `IndexError`, but that doesn't sound like anything we'd want.
out_node = Output(idx=1).make_node(var4)
with pytest.raises(Exception, match="Inconsistent clients list.*"): with pytest.raises(Exception, match="Inconsistent clients list.*"):
fg.add_client(var4, ("output", 1)) fg.add_client(var4, (out_node, 0))
fg.check_integrity() fg.check_integrity()
fg.remove_client(var4, ("output", 1)) fg.remove_client(var4, (out_node, 0))
with pytest.raises(TypeError, match="The first entry of.*"): with pytest.raises(TypeError, match="The first entry of.*"):
fg.add_client(var4, (None, 0)) fg.add_client(var4, (None, 0))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论