提交 9dd929ab authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove public vectorize_node helper and allow `_vectorize_node` to return list of variables

上级 10a4b9f0
...@@ -208,19 +208,13 @@ def graph_replace( ...@@ -208,19 +208,13 @@ def graph_replace(
@singledispatch @singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply: def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise # Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError raise NotImplementedError
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)
def _vectorize_not_needed(op, node, *batched_inputs): def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs) return op.make_node(*batched_inputs).outputs
@overload @overload
...@@ -306,8 +300,16 @@ def vectorize_graph( ...@@ -306,8 +300,16 @@ def vectorize_graph(
vect_vars = dict(zip(inputs, new_inputs, strict=True)) vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in toposort(seq_outputs, blockers=inputs): for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs] vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True): vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
# Compatibility with the old API
vect_outputs = (
vect_node_or_outputs.outputs
if isinstance(vect_node_or_outputs, Apply)
else vect_node_or_outputs
)
for output, vect_output in zip(node.outputs, vect_outputs, strict=True):
if output in vect_vars: if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement, # This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph. # while some of the remaining outputs are still needed in the graph.
......
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.graph.traversal import apply_ancestors
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
...@@ -37,10 +38,15 @@ def local_useless_blockwise(fgraph, node): ...@@ -37,10 +38,15 @@ def local_useless_blockwise(fgraph, node):
""" """
op = node.op op = node.op
inputs = node.inputs inputs = node.inputs
dummy_core_node = op._create_dummy_core_node(node.inputs) dummy_core_node, dummy_inputs = op._create_dummy_core_node(
vect_node = vectorize_node(dummy_core_node, *inputs) inputs, return_dummy_inputs=True
if not isinstance(vect_node.op, Blockwise): )
return copy_stack_trace(node.outputs, vect_node.outputs) outputs = vectorize_graph(dummy_core_node.outputs, dict(zip(dummy_inputs, inputs)))
if not any(
isinstance(vect_node.op, Blockwise)
for vect_node in apply_ancestors(outputs, blockers=inputs)
):
return copy_stack_trace(node.outputs, outputs)
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
......
...@@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node):
def local_blockwise_inc_subtensor(fgraph, node): def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors. """Rewrite blockwised inc_subtensors.
Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch Note: The reason we don't apply this rewrite eagerly in the `_vectorize_node` dispatch
Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites
such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y), such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
......
...@@ -856,7 +856,7 @@ def _vectorize_reshape(op, node, x, shape): ...@@ -856,7 +856,7 @@ def _vectorize_reshape(op, node, x, shape):
else: else:
raise ValueError("Invalid shape length passed into vectorize node of Reshape") raise ValueError("Invalid shape length passed into vectorize node of Reshape")
return reshape(x, new_shape, ndim=len(new_shape)).owner return reshape(x, new_shape, ndim=len(tuple(new_shape))).owner
def reshape( def reshape(
......
...@@ -6,10 +6,10 @@ import pytensor.tensor as pt ...@@ -6,10 +6,10 @@ import pytensor.tensor as pt
from pytensor import config, function, shared from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.replace import ( from pytensor.graph.replace import (
_vectorize_node,
clone_replace, clone_replace,
graph_replace, graph_replace,
vectorize_graph, vectorize_graph,
vectorize_node,
) )
from pytensor.graph.traversal import graph_inputs from pytensor.graph.traversal import graph_inputs
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
...@@ -277,7 +277,7 @@ class TestVectorizeGraph: ...@@ -277,7 +277,7 @@ class TestVectorizeGraph:
# Cases where either x or both of y1 and y2 are given replacements # Cases where either x or both of y1 and y2 are given replacements
new_out = vectorize_graph(out, {x: new_x}) new_out = vectorize_graph(out, {x: new_x})
expected_new_out = pt.add(*vectorize_node(node, new_x).outputs) expected_new_out = pt.add(*_vectorize_node(node.op, node, new_x).outputs)
assert equal_computations([new_out], [expected_new_out]) assert equal_computations([new_out], [expected_new_out])
new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2}) new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2})
...@@ -291,7 +291,9 @@ class TestVectorizeGraph: ...@@ -291,7 +291,9 @@ class TestVectorizeGraph:
# Special case where x is given a replacement as well as only one of y1 and y2 # Special case where x is given a replacement as well as only one of y1 and y2
# The graph combines the replaced variable with the other vectorized output # The graph combines the replaced variable with the other vectorized output
new_out = vectorize_graph(out, {x: new_x, y1: new_y1}) new_out = vectorize_graph(out, {x: new_x, y1: new_y1})
expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1]) expected_new_out = pt.add(
new_y1, _vectorize_node(node.op, node, new_x).outputs[1]
)
assert equal_computations([new_out], [expected_new_out]) assert equal_computations([new_out], [expected_new_out])
def test_multi_output_node_random_variable(self): def test_multi_output_node_random_variable(self):
......
...@@ -7,6 +7,7 @@ from pytensor.compile import get_default_mode ...@@ -7,6 +7,7 @@ from pytensor.compile import get_default_mode
from pytensor.graph.replace import vectorize_graph from pytensor.graph.replace import vectorize_graph
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor import as_tensor
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV from pytensor.tensor.random.basic import NormalRV
...@@ -247,7 +248,7 @@ def test_vectorize(): ...@@ -247,7 +248,7 @@ def test_vectorize():
# Test with size, new size provided # Test with size, new size provided
size = pt.as_tensor(np.array((3,), dtype="int64")) size = pt.as_tensor(np.array((3,), dtype="int64"))
out = normal(vec, size=size) out = normal(vec, size=size)
vect_node = vectorize_graph(out, {vec: mat, size: (2, 3)}).owner vect_node = vectorize_graph(out, {vec: mat, size: as_tensor((2, 3))}).owner
assert isinstance(vect_node.op, NormalRV) assert isinstance(vect_node.op, NormalRV)
assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3) assert tuple(vect_node.op.size_param(vect_node).eval()) == (2, 3)
assert vect_node.op.dist_params(vect_node)[0] is mat assert vect_node.op.dist_params(vect_node)[0] is mat
......
...@@ -11,7 +11,7 @@ from pytensor.compile import get_default_mode, get_mode ...@@ -11,7 +11,7 @@ from pytensor.compile import get_default_mode, get_mode
from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph
from pytensor.graph.replace import vectorize_graph, vectorize_node from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor import ( from pytensor.tensor import (
...@@ -95,8 +95,8 @@ def test_vectorize_blockwise(): ...@@ -95,8 +95,8 @@ def test_vectorize_blockwise():
tns = tensor(shape=(None, None, None)) tns = tensor(shape=(None, None, None))
# Something that falls back to Blockwise # Something that falls back to Blockwise
node = MatrixInverse()(mat).owner out = MatrixInverse()(mat)
vect_node = vectorize_node(node, tns) vect_node = vectorize_graph(out, {mat: tns}).owner
assert isinstance(vect_node.op, Blockwise) and isinstance( assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse vect_node.op.core_op, MatrixInverse
) )
...@@ -105,7 +105,7 @@ def test_vectorize_blockwise(): ...@@ -105,7 +105,7 @@ def test_vectorize_blockwise():
# Useless blockwise # Useless blockwise
tns4 = tensor(shape=(5, None, None, None)) tns4 = tensor(shape=(5, None, None, None))
new_vect_node = vectorize_node(vect_node, tns4) new_vect_node = vectorize_graph(vect_node.out, {tns: tns4}).owner
assert new_vect_node.op is vect_node.op assert new_vect_node.op is vect_node.op
assert isinstance(new_vect_node.op, Blockwise) and isinstance( assert isinstance(new_vect_node.op, Blockwise) and isinstance(
new_vect_node.op.core_op, MatrixInverse new_vect_node.op.core_op, MatrixInverse
...@@ -204,7 +204,7 @@ def test_vectorize_node_default_signature(): ...@@ -204,7 +204,7 @@ def test_vectorize_node_default_signature():
mat = tensor(shape=(5, None)) mat = tensor(shape=(5, None))
node = my_test_op.make_node(vec, mat) node = my_test_op.make_node(vec, mat)
vect_node = vectorize_node(node, mat, mat) vect_node = _vectorize_node(node.op, node, mat, mat)
assert isinstance(vect_node.op, Blockwise) and isinstance( assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MyTestOp vect_node.op.core_op, MyTestOp
) )
......
...@@ -16,7 +16,7 @@ from pytensor.compile.function import function ...@@ -16,7 +16,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_graph
from pytensor.link.basic import PerformLinker from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
...@@ -1042,46 +1042,39 @@ class TestVectorize: ...@@ -1042,46 +1042,39 @@ class TestVectorize:
vec = tensor(shape=(None,)) vec = tensor(shape=(None,))
mat = tensor(shape=(None, None)) mat = tensor(shape=(None, None))
node = exp(vec).owner out = exp(vec)
vect_node = vectorize_node(node, mat) vect_out = vectorize_graph(out, {vec: mat})
assert vect_node.op == exp assert vect_out.owner.op == exp
assert vect_node.inputs[0] is mat assert vect_out.owner.inputs[0] is mat
def test_dimshuffle(self): def test_dimshuffle(self):
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))
node = exp(vec).owner
vect_node = vectorize_node(node, mat)
assert vect_node.op == exp
assert vect_node.inputs[0] is mat
col_mat = tensor(shape=(None, 1)) col_mat = tensor(shape=(None, 1))
tcol_mat = tensor(shape=(None, None, 1)) tcol_mat = tensor(shape=(None, None, 1))
node = col_mat.dimshuffle(0).owner # drop column
vect_node = vectorize_node(node, tcol_mat) out = col_mat.dimshuffle(0) # drop column
assert isinstance(vect_node.op, DimShuffle) vect_out = vectorize_graph(out, {col_mat: tcol_mat})
assert vect_node.op.new_order == (0, 1) assert isinstance(vect_out.owner.op, DimShuffle)
assert vect_node.inputs[0] is tcol_mat assert vect_out.owner.op.new_order == (0, 1)
assert vect_node.outputs[0].type.shape == (None, None) assert vect_out.owner.inputs[0] is tcol_mat
assert vect_out.owner.outputs[0].type.shape == (None, None)
def test_CAReduce(self): def test_CAReduce(self):
mat = tensor(shape=(None, None)) mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None)) tns = tensor(shape=(None, None, None))
node = pt_sum(mat).owner out = pt_sum(mat)
vect_node = vectorize_node(node, tns) vect_out = vectorize_graph(out, {mat: tns})
assert isinstance(vect_node.op, Sum) assert isinstance(vect_out.owner.op, Sum)
assert vect_node.op.axis == (1, 2) assert vect_out.owner.op.axis == (1, 2)
assert vect_node.inputs[0] is tns assert vect_out.owner.inputs[0] is tns
bool_mat = tensor(dtype="bool", shape=(None, None)) bool_mat = tensor(dtype="bool", shape=(None, None))
bool_tns = tensor(dtype="bool", shape=(None, None, None)) bool_tns = tensor(dtype="bool", shape=(None, None, None))
node = pt_any(bool_mat, axis=-2).owner out = pt_any(bool_mat, axis=-2)
vect_node = vectorize_node(node, bool_tns) vect_out = vectorize_graph(out, {bool_mat: bool_tns})
assert isinstance(vect_node.op, Any) assert isinstance(vect_out.owner.op, Any)
assert vect_node.op.axis == (1,) assert vect_out.owner.op.axis == (1,)
assert vect_node.inputs[0] is bool_tns assert vect_out.owner.inputs[0] is bool_tns
def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark): def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
......
...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config ...@@ -21,7 +21,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, grad, numeric_grad from pytensor.gradient import NullTypeGradError, grad, numeric_grad
from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_graph
from pytensor.graph.traversal import ancestors, applys_between from pytensor.graph.traversal import ancestors, applys_between
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
...@@ -1070,11 +1070,10 @@ class TestMaxAndArgmax: ...@@ -1070,11 +1070,10 @@ class TestMaxAndArgmax:
argmax_x = argmax(x, axis=core_axis) argmax_x = argmax(x, axis=core_axis)
arg_max_node = argmax_x.owner vect_out = vectorize_graph(argmax_x, {x: batch_x})
new_node = vectorize_node(arg_max_node, batch_x)
assert isinstance(new_node.op, Argmax) assert isinstance(vect_out.owner.op, Argmax)
assert new_node.op.axis == batch_axis assert vect_out.owner.op.axis == batch_axis
class TestArgminArgmax: class TestArgminArgmax:
......
...@@ -8,7 +8,7 @@ from pytensor import In, Mode, Out, function, grad ...@@ -8,7 +8,7 @@ from pytensor import In, Mode, Out, function, grad
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.replace import clone_replace, vectorize_node from pytensor.graph.replace import clone_replace, vectorize_graph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
...@@ -742,9 +742,9 @@ class TestVectorize: ...@@ -742,9 +742,9 @@ class TestVectorize:
def test_shape(self): def test_shape(self):
vec = tensor(shape=(None,), dtype="float64") vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64") mat = tensor(shape=(None, None), dtype="float64")
node = shape(vec).owner out = shape(vec)
[vect_out] = vectorize_node(node, mat).outputs vect_out = vectorize_graph(out, {vec: mat})
assert equal_computations( assert equal_computations(
[vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))] [vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))]
) )
...@@ -758,8 +758,8 @@ class TestVectorize: ...@@ -758,8 +758,8 @@ class TestVectorize:
mat = tensor(shape=(None, None), dtype="float64") mat = tensor(shape=(None, None), dtype="float64")
tns = tensor(shape=(None, None, None, None), dtype="float64") tns = tensor(shape=(None, None, None, None), dtype="float64")
node = shape(mat).owner out = shape(mat)
[vect_out] = vectorize_node(node, tns).outputs vect_out = vectorize_graph(out, {mat: tns})
assert equal_computations( assert equal_computations(
[vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))] [vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))]
) )
...@@ -779,11 +779,13 @@ class TestVectorize: ...@@ -779,11 +779,13 @@ class TestVectorize:
vec = tensor(shape=(None,), dtype="float64") vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64") mat = tensor(shape=(None, None), dtype="float64")
shape = (-1, x) shape = as_tensor_variable([-1, x])
node = reshape(vec, shape).owner out = reshape(vec, shape)
[vect_out] = vectorize_node(node, mat, shape).outputs vect_out = vectorize_graph(out, {vec: mat})
assert equal_computations([vect_out], [reshape(mat, (*mat.shape[:1], -1, x))]) utt.assert_equal_computations(
[vect_out], [reshape(mat, (*mat.shape[:1], *stack((-1, x))))]
)
x_test_value = 2 x_test_value = 2
mat_test_value = np.ones((5, 6)) mat_test_value = np.ones((5, 6))
...@@ -795,12 +797,12 @@ class TestVectorize: ...@@ -795,12 +797,12 @@ class TestVectorize:
ref_fn(x_test_value, mat_test_value), ref_fn(x_test_value, mat_test_value),
) )
new_shape = (5, -1, x) new_shape = as_tensor_variable((5, -1, x))
[vect_out] = vectorize_node(node, mat, new_shape).outputs vect_out = vectorize_graph(out, {vec: mat, shape: new_shape})
assert equal_computations([vect_out], [reshape(mat, new_shape)]) utt.assert_equal_computations([vect_out], [reshape(mat, new_shape)])
new_shape = stack([[-1, x], [x - 1, -1]], axis=0) new_shape = stack([[-1, x], [x - 1, -1]], axis=0)
[vect_out] = vectorize_node(node, vec, new_shape).outputs vect_out = vectorize_graph(out, {shape: new_shape})
vec_test_value = np.arange(6) vec_test_value = np.arange(6)
np.testing.assert_allclose( np.testing.assert_allclose(
vect_out.eval({x: 3, vec: vec_test_value}), vect_out.eval({x: 3, vec: vec_test_value}),
...@@ -811,13 +813,13 @@ class TestVectorize: ...@@ -811,13 +813,13 @@ class TestVectorize:
ValueError, ValueError,
match="Invalid shape length passed into vectorize node of Reshape", match="Invalid shape length passed into vectorize node of Reshape",
): ):
vectorize_node(node, vec, (5, 2, x)) vectorize_graph(out, {shape: as_tensor_variable((5, 2, x))})
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="Invalid shape length passed into vectorize node of Reshape", match="Invalid shape length passed into vectorize node of Reshape",
): ):
vectorize_node(node, mat, (5, 3, 2, x)) vectorize_graph(out, {vec: mat, shape: as_tensor_variable((5, 3, 2, x))})
def test_specify_shape(self): def test_specify_shape(self):
x = scalar("x", dtype=int) x = scalar("x", dtype=int)
...@@ -825,27 +827,9 @@ class TestVectorize: ...@@ -825,27 +827,9 @@ class TestVectorize:
tns = tensor(shape=(None, None, None)) tns = tensor(shape=(None, None, None))
shape = (x, None) shape = (x, None)
node = specify_shape(mat, shape).owner out = specify_shape(mat, shape)
vect_node = vectorize_node(node, tns, *shape) vect_out = vectorize_graph(out, {mat: tns})
assert equal_computations( assert equal_computations([vect_out], [specify_shape(tns, (None, x, None))])
vect_node.outputs, [specify_shape(tns, (None, x, None))]
)
new_shape = (5, 2, x)
vect_node = vectorize_node(node, tns, *new_shape)
assert equal_computations(vect_node.outputs, [specify_shape(tns, (5, 2, x))])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
vectorize_node(node, mat, *([x, x], None)) vectorize_graph(out, {x: as_tensor_variable([x, x])})
with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, mat, *(5, 2, x))
with pytest.raises(
ValueError,
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, tns, *(5, 3, 2, x))
...@@ -9,7 +9,7 @@ from scipy.special import softmax as scipy_softmax ...@@ -9,7 +9,7 @@ from scipy.special import softmax as scipy_softmax
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.special import ( from pytensor.tensor.special import (
LogSoftmax, LogSoftmax,
...@@ -168,18 +168,18 @@ def test_vectorize_softmax(op, constructor, core_axis, batch_axis): ...@@ -168,18 +168,18 @@ def test_vectorize_softmax(op, constructor, core_axis, batch_axis):
x = tensor(shape=(5, 5, 5, 5)) x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5)) batch_x = tensor(shape=(3, 5, 5, 5, 5))
node = constructor(x, axis=core_axis).owner out = constructor(x, axis=core_axis)
assert isinstance(node.op, op) assert isinstance(out.owner.op, op)
new_node = vectorize_node(node, batch_x) new_out = vectorize_graph(out, {x: batch_x})
if len(batch_axis) == 1: if len(batch_axis) == 1:
assert isinstance(new_node.op, op) assert isinstance(new_out.owner.op, op)
assert (new_node.op.axis,) == batch_axis assert (new_out.owner.op.axis,) == batch_axis
else: else:
assert isinstance(new_node.op, Blockwise) and isinstance( assert isinstance(new_out.owner.op, Blockwise) and isinstance(
new_node.op.core_op, op new_out.owner.op.core_op, op
) )
assert new_node.op.core_op.axis == core_axis assert new_out.owner.op.core_op.axis == core_axis
def test_poch(): def test_poch():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论