提交 9dd929ab authored 作者: ricardoV94 提交者: Ricardo Vieira

Remove public vectorize_node helper and allow `_vectorize_node` to return list of variables

上级 10a4b9f0
......@@ -208,19 +208,13 @@ def graph_replace(
@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)
def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)
return op.make_node(*batched_inputs).outputs
@overload
......@@ -306,8 +300,16 @@ def vectorize_graph(
vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):
vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
# Compatibility with the old API
vect_outputs = (
vect_node_or_outputs.outputs
if isinstance(vect_node_or_outputs, Apply)
else vect_node_or_outputs
)
for output, vect_output in zip(node.outputs, vect_outputs, strict=True):
if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
......
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.graph.traversal import apply_ancestors
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
......@@ -37,10 +38,15 @@ def local_useless_blockwise(fgraph, node):
"""
op = node.op
inputs = node.inputs
dummy_core_node = op._create_dummy_core_node(node.inputs)
vect_node = vectorize_node(dummy_core_node, *inputs)
if not isinstance(vect_node.op, Blockwise):
return copy_stack_trace(node.outputs, vect_node.outputs)
dummy_core_node, dummy_inputs = op._create_dummy_core_node(
inputs, return_dummy_inputs=True
)
outputs = vectorize_graph(dummy_core_node.outputs, dict(zip(dummy_inputs, inputs)))
if not any(
isinstance(vect_node.op, Blockwise)
for vect_node in apply_ancestors(outputs, blockers=inputs)
):
return copy_stack_trace(node.outputs, outputs)
@node_rewriter([Blockwise])
......
......@@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node):
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.
Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
Note: The reason we don't apply this rewrite eagerly in the `_vectorize_node` dispatch
Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites
such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
......
......@@ -856,7 +856,7 @@ def _vectorize_reshape(op, node, x, shape):
else:
raise ValueError("Invalid shape length passed into vectorize node of Reshape")
return reshape(x, new_shape, ndim=len(new_shape)).owner
return reshape(x, new_shape, ndim=len(tuple(new_shape))).owner
def reshape(
......
......@@ -6,10 +6,10 @@ import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations
from pytensor.graph.replace import (
_vectorize_node,
clone_replace,
graph_replace,
vectorize_graph,
vectorize_node,
)
from pytensor.graph.traversal import graph_inputs
from pytensor.tensor import dvector, fvector, vector
......@@ -277,7 +277,7 @@ class TestVectorizeGraph:
# Cases where either x or both of y1 and y2 are given replacements
new_out = vectorize_graph(out, {x: new_x})
expected_new_out = pt.add(*vectorize_node(node, new_x).outputs)
expected_new_out = pt.add(*_vectorize_node(node.op, node, new_x).outputs)
assert equal_computations([new_out], [expected_new_out])
new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2})
......@@ -291,7 +291,9 @@ class TestVectorizeGraph:
# Special case where x is given a replacement as well as only one of y1 and y2
# The graph combines the replaced variable with the other vectorized output
new_out = vectorize_graph(out, {x: new_x, y1: new_y1})
expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1])
expected_new_out = pt.add(
new_y1, _vectorize_node(node.op, node, new_x).outputs[1]
)
assert equal_computations([new_out], [expected_new_out])
def test_multi_output_node_random_variable(self):
......
......@@ -7,6 +7,7 @@ from pytensor.compile import get_default_mode
from pytensor.graph.replace import vectorize_graph
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert
from pytensor.tensor import as_tensor
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV
......@@ -247,7 +248,7 @@ def test_vectorize():
# Test with size, new size provided
size = pt.as_tensor(np.array((3,), dtype="int64"))
out = normal(vec, size=size)
vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner
vect_node = vectorize_graph(out, {vec: mat, size: as_tensor((2, 3))}).owner
assert isinstance(vect_node.op, NormalRV)
assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3)
assert vect_node.op.dist_params(vect_node)[0] is mat
......
......@@ -11,7 +11,7 @@ from pytensor.compile import get_default_mode, get_mode
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.gradient import grad
from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph
from pytensor.graph.replace import vectorize_graph, vectorize_node
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.tensor import (
......@@ -95,8 +95,8 @@ def test_vectorize_blockwise():
tns = tensor(shape=(None, None, None))
# Something that falls back to Blockwise
node = MatrixInverse()(mat).owner
vect_node = vectorize_node(node, tns)
out = MatrixInverse()(mat)
vect_node = vectorize_graph(out, {mat: tns}).owner
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse
)
......@@ -105,7 +105,7 @@ def test_vectorize_blockwise():
# Useless blockwise
tns4 = tensor(shape=(5, None, None, None))
new_vect_node = vectorize_node(vect_node, tns4)
new_vect_node = vectorize_graph(vect_node.out, {tns: tns4}).owner
assert new_vect_node.op is vect_node.op
assert isinstance(new_vect_node.op, Blockwise) and isinstance(
new_vect_node.op.core_op, MatrixInverse
......@@ -204,7 +204,7 @@ def test_vectorize_node_default_signature():
mat = tensor(shape=(5, None))
node = my_test_op.make_node(vec, mat)
vect_node = vectorize_node(node, mat, mat)
vect_node = _vectorize_node(node.op, node, mat, mat)
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MyTestOp
)
......
......@@ -16,7 +16,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.numba import NumbaLinker
......@@ -1042,46 +1042,39 @@ class TestVectorize:
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))
node = exp(vec).owner
vect_node = vectorize_node(node, mat)
assert vect_node.op == exp
assert vect_node.inputs[0] is mat
out = exp(vec)
vect_out = vectorize_graph(out, {vec: mat})
assert vect_out.owner.op == exp
assert vect_out.owner.inputs[0] is mat
def test_dimshuffle(self):
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))
node = exp(vec).owner
vect_node = vectorize_node(node, mat)
assert vect_node.op == exp
assert vect_node.inputs[0] is mat
col_mat = tensor(shape=(None, 1))
tcol_mat = tensor(shape=(None, None, 1))
node = col_mat.dimshuffle(0).owner # drop column
vect_node = vectorize_node(node, tcol_mat)
assert isinstance(vect_node.op, DimShuffle)
assert vect_node.op.new_order == (0, 1)
assert vect_node.inputs[0] is tcol_mat
assert vect_node.outputs[0].type.shape == (None, None)
out = col_mat.dimshuffle(0) # drop column
vect_out = vectorize_graph(out, {col_mat: tcol_mat})
assert isinstance(vect_out.owner.op, DimShuffle)
assert vect_out.owner.op.new_order == (0, 1)
assert vect_out.owner.inputs[0] is tcol_mat
assert vect_out.owner.outputs[0].type.shape == (None, None)
def test_CAReduce(self):
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))
node = pt_sum(mat).owner
vect_node = vectorize_node(node, tns)
assert isinstance(vect_node.op, Sum)
assert vect_node.op.axis == (1, 2)
assert vect_node.inputs[0] is tns
out = pt_sum(mat)
vect_out = vectorize_graph(out, {mat: tns})
assert isinstance(vect_out.owner.op, Sum)
assert vect_out.owner.op.axis == (1, 2)
assert vect_out.owner.inputs[0] is tns
bool_mat = tensor(dtype="bool", shape=(None, None))
bool_tns = tensor(dtype="bool", shape=(None, None, None))
node = pt_any(bool_mat, axis=-2).owner
vect_node = vectorize_node(node, bool_tns)
assert isinstance(vect_node.op, Any)
assert vect_node.op.axis == (1,)
assert vect_node.inputs[0] is bool_tns
out = pt_any(bool_mat, axis=-2)
vect_out = vectorize_graph(out, {bool_mat: bool_tns})
assert isinstance(vect_out.owner.op, Any)
assert vect_out.owner.op.axis == (1,)
assert vect_out.owner.inputs[0] is bool_tns
def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
......
......@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, grad, numeric_grad
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.graph.traversal import ancestors, applys_between
from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker
......@@ -1070,11 +1070,10 @@ class TestMaxAndArgmax:
argmax_x = argmax(x, axis=core_axis)
arg_max_node = argmax_x.owner
new_node = vectorize_node(arg_max_node, batch_x)
vect_out = vectorize_graph(argmax_x, {x: batch_x})
assert isinstance(new_node.op, Argmax)
assert new_node.op.axis == batch_axis
assert isinstance(vect_out.owner.op, Argmax)
assert vect_out.owner.op.axis == batch_axis
class TestArgminArgmax:
......
......@@ -8,7 +8,7 @@ from pytensor import In, Mode, Out, function, grad
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.replace import clone_replace, vectorize_graph
from pytensor.graph.type import Type
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
......@@ -742,9 +742,9 @@ class TestVectorize:
def test_shape(self):
vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64")
node = shape(vec).owner
out = shape(vec)
[vect_out] = vectorize_node(node, mat).outputs
vect_out = vectorize_graph(out, {vec: mat})
assert equal_computations(
[vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))]
)
......@@ -758,8 +758,8 @@ class TestVectorize:
mat = tensor(shape=(None, None), dtype="float64")
tns = tensor(shape=(None, None, None, None), dtype="float64")
node = shape(mat).owner
[vect_out] = vectorize_node(node, tns).outputs
out = shape(mat)
vect_out = vectorize_graph(out, {mat: tns})
assert equal_computations(
[vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))]
)
......@@ -779,11 +779,13 @@ class TestVectorize:
vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64")
shape = (-1, x)
node = reshape(vec, shape).owner
shape = as_tensor_variable([-1, x])
out = reshape(vec, shape)
[vect_out] = vectorize_node(node, mat, shape).outputs
assert equal_computations([vect_out], [reshape(mat, (*mat.shape[:1], -1, x))])
vect_out = vectorize_graph(out, {vec: mat})
utt.assert_equal_computations(
[vect_out], [reshape(mat, (*mat.shape[:1], *stack((-1, x))))]
)
x_test_value = 2
mat_test_value = np.ones((5, 6))
......@@ -795,12 +797,12 @@ class TestVectorize:
ref_fn(x_test_value, mat_test_value),
)
new_shape = (5, -1, x)
[vect_out] = vectorize_node(node, mat, new_shape).outputs
assert equal_computations([vect_out], [reshape(mat, new_shape)])
new_shape = as_tensor_variable((5, -1, x))
vect_out = vectorize_graph(out, {vec: mat, shape: new_shape})
utt.assert_equal_computations([vect_out], [reshape(mat, new_shape)])
new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
[vect_out] = vectorize_node(node, vec, new_shape).outputs
vect_out = vectorize_graph(out, {shape: new_shape})
vec_test_value = np.arange(6)
np.testing.assert_allclose(
vect_out.eval({x: 3, vec: vec_test_value}),
......@@ -811,13 +813,13 @@ class TestVectorize:
ValueError,
match="Invalid shape length passed into vectorize node of Reshape",
):
vectorize_node(node, vec, (5, 2, x))
vectorize_graph(out, {shape: as_tensor_variable((5, 2, x))})
with pytest.raises(
ValueError,
match="Invalid shape length passed into vectorize node of Reshape",
):
vectorize_node(node, mat, (5, 3, 2, x))
vectorize_graph(out, {vec: mat, shape: as_tensor_variable((5, 3, 2, x))})
def test_specify_shape(self):
x = scalar("x", dtype=int)
......@@ -825,27 +827,9 @@ class TestVectorize:
tns = tensor(shape=(None, None, None))
shape = (x, None)
node = specify_shape(mat, shape).owner
vect_node = vectorize_node(node, tns, *shape)
assert equal_computations(
vect_node.outputs, [specify_shape(tns, (None, x, None))]
)
new_shape = (5, 2, x)
vect_node = vectorize_node(node, tns, *new_shape)
assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])
out = specify_shape(mat, shape)
vect_out = vectorize_graph(out, {mat: tns})
assert equal_computations([vect_out], [specify_shape(tns, (None, x, None))])
with pytest.raises(NotImplementedError):
vectorize_node(node, mat, *([x, x], None))
with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, mat, *(5, 2, x))
with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, tns, *(5, 3, 2, x))
vectorize_graph(out, {x: as_tensor_variable([x, x])})
......@@ -9,7 +9,7 @@ from scipy.special import softmax as scipy_softmax
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.special import (
LogSoftmax,
......@@ -168,18 +168,18 @@ def test_vectorize_softmax(op, constructor, core_axis, batch_axis):
x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5))
node = constructor(x, axis=core_axis).owner
assert isinstance(node.op, op)
out = constructor(x, axis=core_axis)
assert isinstance(out.owner.op, op)
new_node = vectorize_node(node, batch_x)
new_out = vectorize_graph(out, {x: batch_x})
if len(batch_axis) == 1:
assert isinstance(new_node.op, op)
assert (new_node.op.axis,) == batch_axis
assert isinstance(new_out.owner.op, op)
assert (new_out.owner.op.axis,) == batch_axis
else:
assert isinstance(new_node.op, Blockwise) and isinstance(
new_node.op.core_op, op
assert isinstance(new_out.owner.op, Blockwise) and isinstance(
new_out.owner.op.core_op, op
)
assert new_node.op.core_op.axis == core_axis
assert new_out.owner.op.core_op.axis == core_axis
def test_poch():
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论