提交 8483b046 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Move io_connection_pattern to graph/op.py

上级 c6dae89f
...@@ -17,11 +17,10 @@ from pytensor.graph.basic import ( ...@@ -17,11 +17,10 @@ from pytensor.graph.basic import (
NominalVariable, NominalVariable,
Variable, Variable,
graph_inputs, graph_inputs,
io_connection_pattern,
) )
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
......
...@@ -1633,71 +1633,6 @@ def default_node_formatter(op, argstrings): ...@@ -1633,71 +1633,6 @@ def default_node_formatter(op, argstrings):
return f"{op.op}({', '.join(argstrings)})" return f"{op.op}({', '.join(argstrings)})"
def io_connection_pattern(inputs, outputs):
"""Return the connection pattern of a subgraph defined by given inputs and outputs."""
inner_nodes = io_toposort(inputs, outputs)
# Initialize 'connect_pattern_by_var' by establishing each input as
# connected only to itself
connect_pattern_by_var = {}
nb_inputs = len(inputs)
for i in range(nb_inputs):
input = inputs[i]
inp_connection_pattern = [i == j for j in range(nb_inputs)]
connect_pattern_by_var[input] = inp_connection_pattern
# Iterate through the nodes used to produce the outputs from the
# inputs and, for every node, infer their connection pattern to
# every input from the connection patterns of their parents.
for n in inner_nodes:
# Get the connection pattern of the inner node's op. If the op
# does not define a connection_pattern method, assume that
# every node output is connected to every node input
try:
op_connection_pattern = n.op.connection_pattern(n)
except AttributeError:
op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs)
# For every output of the inner node, figure out which inputs it
# is connected to by combining the connection pattern of the inner
# node and the connection patterns of the inner node's inputs.
for out_idx in range(len(n.outputs)):
out = n.outputs[out_idx]
out_connection_pattern = [False] * nb_inputs
for inp_idx in range(len(n.inputs)):
inp = n.inputs[inp_idx]
if inp in connect_pattern_by_var:
inp_connection_pattern = connect_pattern_by_var[inp]
# If the node output is connected to the node input, it
# means it is connected to every inner input that the
# node inputs is connected to
if op_connection_pattern[inp_idx][out_idx]:
out_connection_pattern = [
out_connection_pattern[i] or inp_connection_pattern[i]
for i in range(nb_inputs)
]
# Store the connection pattern of the node output
connect_pattern_by_var[out] = out_connection_pattern
# Obtain the global connection pattern by combining the
# connection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
return global_connection_pattern
def op_as_string( def op_as_string(
i, op, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter i, op, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter
): ):
......
...@@ -13,7 +13,7 @@ from typing import ( ...@@ -13,7 +13,7 @@ from typing import (
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable, io_toposort
from pytensor.graph.utils import ( from pytensor.graph.utils import (
MetaObject, MetaObject,
TestValueError, TestValueError,
...@@ -753,3 +753,68 @@ def get_test_values(*args: Variable) -> Any | list[Any]: ...@@ -753,3 +753,68 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
return rval return rval
return [tuple(rval)] return [tuple(rval)]
def io_connection_pattern(inputs, outputs):
"""Return the connection pattern of a subgraph defined by given inputs and outputs."""
inner_nodes = io_toposort(inputs, outputs)
# Initialize 'connect_pattern_by_var' by establishing each input as
# connected only to itself
connect_pattern_by_var = {}
nb_inputs = len(inputs)
for i in range(nb_inputs):
input = inputs[i]
inp_connection_pattern = [i == j for j in range(nb_inputs)]
connect_pattern_by_var[input] = inp_connection_pattern
# Iterate through the nodes used to produce the outputs from the
# inputs and, for every node, infer their connection pattern to
# every input from the connection patterns of their parents.
for n in inner_nodes:
# Get the connection pattern of the inner node's op. If the op
# does not define a connection_pattern method, assume that
# every node output is connected to every node input
try:
op_connection_pattern = n.op.connection_pattern(n)
except AttributeError:
op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs)
# For every output of the inner node, figure out which inputs it
# is connected to by combining the connection pattern of the inner
# node and the connection patterns of the inner node's inputs.
for out_idx in range(len(n.outputs)):
out = n.outputs[out_idx]
out_connection_pattern = [False] * nb_inputs
for inp_idx in range(len(n.inputs)):
inp = n.inputs[inp_idx]
if inp in connect_pattern_by_var:
inp_connection_pattern = connect_pattern_by_var[inp]
# If the node output is connected to the node input, it
# means it is connected to every inner input that the
# node inputs is connected to
if op_connection_pattern[inp_idx][out_idx]:
out_connection_pattern = [
out_connection_pattern[i] or inp_connection_pattern[i]
for i in range(nb_inputs)
]
# Store the connection pattern of the node output
connect_pattern_by_var[out] = out_connection_pattern
# Obtain the global connection pattern by combining the
# connection patterns of the individual outputs
global_connection_pattern = [[] for o in range(len(inputs))]
for out in outputs:
out_connection_pattern = connect_pattern_by_var.get(out)
if out_connection_pattern is None:
# the output is completely isolated from inputs
out_connection_pattern = [False] * len(inputs)
for i in range(len(inputs)):
global_connection_pattern[i].append(out_connection_pattern[i])
return global_connection_pattern
...@@ -68,10 +68,9 @@ from pytensor.graph.basic import ( ...@@ -68,10 +68,9 @@ from pytensor.graph.basic import (
Variable, Variable,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_connection_pattern,
) )
from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.graph.utils import InconsistencyError, MissingInputError
......
...@@ -584,11 +584,6 @@ def test_apply_depends_on(): ...@@ -584,11 +584,6 @@ def test_apply_depends_on():
assert apply_depends_on(o3.owner, [o1.owner, o2.owner]) assert apply_depends_on(o3.owner, [o1.owner, o2.owner])
@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
raise AssertionError()
def test_get_var_by_name(): def test_get_var_by_name():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
......
...@@ -275,3 +275,8 @@ def test_call_name(multi_output): ...@@ -275,3 +275,8 @@ def test_call_name(multi_output):
res_nameless = single_op(x) res_nameless = single_op(x)
assert res_nameless.name is None assert res_nameless.name is None
@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
raise AssertionError()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论