Commit d6113953 authored by Ricardo Vieira, committed by Thomas Wiecki

Implement vectorize utility

Parent: a3eed0b4
...@@ -9,7 +9,7 @@ from pytensor.graph.basic import ( ...@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone, clone,
ancestors, ancestors,
) )
from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
......
from functools import partial from functools import partial, singledispatch
from typing import Iterable, Optional, Sequence, Union, cast, overload from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
...@@ -198,3 +199,65 @@ def graph_replace( ...@@ -198,3 +199,65 @@ def graph_replace(
return list(fg.outputs) return list(fg.outputs)
else: else:
return fg.outputs[0] return fg.outputs[0]
@singledispatch
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
    """Return a vectorized version of `node` applied to `batched_inputs`.

    Dispatches to the implementation registered for ``type(node.op)``.
    """
    return _vectorize_node(node.op, node, *batched_inputs)
def vectorize(
    outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
) -> Sequence[Variable]:
    """Vectorize an output graph, given a mapping from old variables to expanded counterparts.

    Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.

    Examples
    --------
    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        from pytensor.graph import vectorize

        # Original graph
        x = pt.vector("x")
        y = pt.exp(x) / pt.sum(pt.exp(x))

        # Vectorized graph
        new_x = pt.matrix("new_x")
        [new_y] = vectorize([y], {x: new_x})

        fn = pytensor.function([new_x], new_y)
        fn([[0, 1, 2], [2, 1, 0]])
        # array([[0.09003057, 0.24472847, 0.66524096],
        #        [0.66524096, 0.24472847, 0.09003057]])

    """
    inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())

    # Memoize already-vectorized variables, seeded with the boundary inputs
    # mapped to their batched replacements (identity when not replaced).
    # The memo doubles as a node cache: shared subgraphs and multi-output
    # nodes are vectorized only once.
    memo: dict[Variable, Variable] = {
        inp: vectorize.get(inp, inp) for inp in inputs
    }

    def transform(var: Variable) -> Variable:
        # Recursively vectorize `var` by vectorizing its owner node.
        if var in memo:
            return memo[var]
        node = var.owner
        batched_inputs = [transform(inp) for inp in node.inputs]
        batched_node = vectorize_node(node, *batched_inputs)
        # Record every output of the node so siblings are not recomputed.
        for old_out, new_out in zip(node.outputs, batched_node.outputs):
            memo[old_out] = new_out
        return memo[var]

    return [transform(out) for out in outputs]
...@@ -2,7 +2,8 @@ from itertools import chain ...@@ -2,7 +2,8 @@ from itertools import chain
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple
from pytensor.compile import rebuild_collect_shared from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone from pytensor.graph.basic import Constant, Variable, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
......
import re import re
from functools import singledispatch
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
import numpy as np import numpy as np
...@@ -9,6 +8,7 @@ from pytensor.gradient import DisconnectedType ...@@ -9,6 +8,7 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
...@@ -72,8 +72,8 @@ def safe_signature( ...@@ -72,8 +72,8 @@ def safe_signature(
return f"{inputs_sig}->{outputs_sig}" return f"{inputs_sig}->{outputs_sig}"
@singledispatch @_vectorize_node.register(Op)
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
if hasattr(op, "gufunc_signature"): if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature signature = op.gufunc_signature
else: else:
...@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: ...@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)
class Blockwise(Op): class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions. """Generalizes a core `Op` to work with batched dimensions.
...@@ -279,42 +273,18 @@ class Blockwise(Op): ...@@ -279,42 +273,18 @@ class Blockwise(Op):
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
batch_ndims = self._batch_ndim_from_outputs(outputs) igrads = vectorize(
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
def transform(var): vectorize=dict(
# From a graph of ScalarOps, make a graph of Broadcast ops. zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
if isinstance(var.type, (NullType, DisconnectedType)): ),
return var )
if var in core_inputs:
return inputs[core_inputs.index(var)]
if var in core_outputs:
return outputs[core_outputs.index(var)]
if var in core_ograds:
return ograds[core_ograds.index(var)]
node = var.owner
# The gradient contains a constant, which may be responsible for broadcasting
if node is None:
if batch_ndims:
var = shape_padleft(var, batch_ndims)
return var
batched_inputs = [transform(inp) for inp in node.inputs]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return batched_var
ret = []
for core_igrad, ipt in zip(core_igrads, inputs):
# Undefined gradient
if core_igrad is None:
ret.append(None)
else:
ret.append(transform(core_igrad))
return ret igrads_iter = iter(igrads)
return [
None if core_igrad is None else next(igrads_iter)
for core_igrad in core_igrads
]
def L_op(self, inputs, outs, ograds): def L_op(self, inputs, outs, ograds):
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
......
...@@ -8,6 +8,7 @@ from pytensor.configdefaults import config ...@@ -8,6 +8,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
...@@ -22,7 +23,7 @@ from pytensor.scalar.basic import transfer_type, upcast ...@@ -22,7 +23,7 @@ from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed from pytensor.tensor.blockwise import vectorize_not_needed
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
continuous_dtypes, continuous_dtypes,
......
...@@ -7,6 +7,7 @@ import pytensor ...@@ -7,6 +7,7 @@ import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -17,7 +18,6 @@ from pytensor.tensor.basic import ( ...@@ -17,7 +18,6 @@ from pytensor.tensor.basic import (
get_vector_length, get_vector_length,
infer_static_shape, infer_static_shape,
) )
from pytensor.tensor.blockwise import _vectorize_node
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
broadcast_params, broadcast_params,
......
from pytensor.compile.mode import optdb from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise, vectorize_node from pytensor.tensor.blockwise import Blockwise
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
......
import numpy as np import numpy as np
import pytest import pytest
import scipy.special
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, shared from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable from tests.graph.utils import MyOp, MyVariable
...@@ -223,3 +224,21 @@ class TestGraphReplace: ...@@ -223,3 +224,21 @@ class TestGraphReplace:
assert oc[0] is o assert oc[0] is o
with pytest.raises(ValueError, match="Some replacements were not used"): with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True) oc = graph_replace([o], {fake: x.clone()}, strict=True)
class TestVectorize:
    # TODO: Add tests with multiple outputs, constants, and other singleton types

    def test_basic(self):
        # Build a softmax over a vector, then vectorize it to act row-wise
        # on a matrix.
        vec = pt.vector("x")
        softmax_out = pt.exp(vec) / pt.sum(pt.exp(vec))

        mat = pt.matrix("new_x")
        (batched_out,) = vectorize([softmax_out], {vec: mat})

        compiled = function([mat], batched_out)
        data = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
        expected = scipy.special.softmax(data, axis=-1)
        np.testing.assert_allclose(compiled(data), expected)
...@@ -4,8 +4,8 @@ import pytest ...@@ -4,8 +4,8 @@ import pytest
import pytensor.tensor as at import pytensor.tensor as at
from pytensor import config, function from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad from pytensor.gradient import NullTypeGradError, grad
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.blockwise import vectorize_node
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
......
...@@ -8,8 +8,9 @@ import pytensor ...@@ -8,8 +8,9 @@ import pytensor
from pytensor import config from pytensor import config
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve from pytensor.tensor.slinalg import Cholesky, Solve
......
...@@ -13,11 +13,11 @@ from pytensor.compile.mode import Mode ...@@ -13,11 +13,11 @@ from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second from pytensor.tensor.basic import second
from pytensor.tensor.blockwise import vectorize_node
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Any, Sum from pytensor.tensor.math import Any, Sum
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment