提交 0ebc83b3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix `vectorize_graph` bug when replacements were provided for only some outputs of a node

The provided replacement could be silently ignored and overwritten by the new output of the vectorized node. The changes also avoid vectorizing multiple-output nodes when none of the unreplaced outputs are needed.
上级 c4ae6e34
......@@ -1439,15 +1439,16 @@ def io_toposort(
order = []
while todo:
cur = todo.pop()
# We suppose that all outputs are always computed
if cur.outputs[0] in computed:
if all(out in computed for out in cur.outputs):
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
order.append(cur)
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
todo.extend(
i.owner for i in cur.inputs if (i.owner and i not in computed)
)
return order
compute_deps = None
......
......@@ -306,6 +306,11 @@ def vectorize_graph(
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
# We make sure we don't overwrite the provided replacement with the newly vectorized output
continue
vect_vars[output] = vect_output
seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
......
......@@ -34,7 +34,7 @@ from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import TensorVariable
from tests.graph.utils import MyInnerGraphOp
from tests.graph.utils import MyInnerGraphOp, op_multiple_outputs
class MyType(Type):
......@@ -287,6 +287,45 @@ class TestToposort:
all = io_toposort([], o0.outputs)
assert all == [o0]
def test_multi_output_nodes(self):
    """io_toposort on a graph with multi-output nodes, for several choices of inputs."""
    l0, r0 = op_multiple_outputs(shared(0.0))
    l1, r1 = op_multiple_outputs(shared(0.0))

    v0 = r0 + 1
    v1 = pt.exp(v0)
    out = r1 * v1

    # These three single-output nodes are needed in every case below
    common = {v0.owner, v1.owner, out.owner}

    # When either r0 or r1 is provided as an input, the respective node
    # shouldn't be part of the toposort
    assert set(io_toposort([], [out])) == common | {r0.owner, r1.owner}
    assert set(io_toposort([r0], [out])) == common | {r1.owner}
    assert set(io_toposort([r1], [out])) == common | {r0.owner}
    assert set(io_toposort([r0, r1], [out])) == common

    # Providing l0 and/or l1 does not exempt their nodes: the r outputs are
    # still required, so both multi-output nodes must be computed
    assert set(io_toposort([l0, l1], [out])) == common | {r0.owner, r1.owner}
class TestEval:
def setup_method(self):
......
......@@ -5,10 +5,15 @@ import scipy.special
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.replace import (
clone_replace,
graph_replace,
vectorize_graph,
vectorize_node,
)
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs
class TestCloneReplace:
......@@ -227,8 +232,6 @@ class TestGraphReplace:
class TestVectorizeGraph:
# TODO: Add tests with multiple outputs, constants, and other singleton types
def test_basic(self):
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))
......@@ -260,3 +263,63 @@ class TestVectorizeGraph:
new_y1_res, new_y2_res = fn(new_x_test)
np.testing.assert_allclose(new_y1_res, [0, 3, 6])
np.testing.assert_allclose(new_y2_res, [2, 5, 8])
def test_multi_output_node(self):
    """vectorize_graph when replacements cover none, all, or only some outputs
    of a multi-output node."""
    x = pt.scalar("x")
    node = op_multiple_outputs.make_node(x)
    y1, y2 = node.outputs
    out = pt.add(y1, y2)

    new_x = pt.vector("new_x")
    new_y1 = pt.vector("new_y1")
    new_y2 = pt.vector("new_y2")

    # Replacing only the input: both outputs come from the vectorized node
    expected = pt.add(*vectorize_node(node, new_x).outputs)
    assert equal_computations([vectorize_graph(out, {x: new_x})], [expected])

    # Replacing both outputs (with or without the input as well) bypasses the node
    expected = pt.add(new_y1, new_y2)
    assert equal_computations(
        [vectorize_graph(out, {y1: new_y1, y2: new_y2})], [expected]
    )
    assert equal_computations(
        [vectorize_graph(out, {x: new_x, y1: new_y1, y2: new_y2})], [expected]
    )

    # Special case: the input and only one output are replaced. The graph must
    # keep the provided replacement for y1 and combine it with the other
    # output of the vectorized node
    expected = pt.add(new_y1, vectorize_node(node, new_x).outputs[1])
    assert equal_computations(
        [vectorize_graph(out, {x: new_x, y1: new_y1})], [expected]
    )
def test_multi_output_node_random_variable(self):
    """Regression test for issue #569.

    Functionally, it covers the same case as `test_multi_output_node`:
    replacements are provided for only some outputs of multi-output nodes.
    """
    # RandomVariables have two outputs, a hidden RNG and the visible draws;
    # beta0/beta1 are the draws (the second output of each node)
    beta0 = pt.random.normal(name="beta0")
    beta1 = pt.random.normal(name="beta1")

    out1 = beta0 + 1
    out2 = beta1 * pt.exp(out1)

    # We replace the second output of each RandomVariable
    new_beta0 = pt.tensor("new_beta0", shape=(3,))
    new_beta1 = pt.tensor("new_beta1", shape=(3,))

    new_outs = vectorize_graph(
        [out1, out2],
        replace={
            beta0: new_beta0,
            beta1: new_beta1,
        },
    )

    expected_new_outs = [
        new_beta0 + 1,
        new_beta1 * pt.exp(new_beta0 + 1),
    ]
    assert equal_computations(new_outs, expected_new_outs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论