提交 d6113953 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Implement vectorize utility

上级 a3eed0b4
......@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone,
ancestors,
)
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
......
from functools import partial
from typing import Iterable, Optional, Sequence, Union, cast, overload
from functools import partial, singledispatch
from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
......@@ -198,3 +199,65 @@ def graph_replace(
return list(fg.outputs)
else:
return fg.outputs[0]
@singledispatch
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)
def vectorize(
outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
) -> Sequence[Variable]:
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
Examples
--------
.. code-block:: python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize
# Original graph
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))
# Vectorized graph
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
fn = pytensor.function([new_x], new_y)
fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096],
# [0.66524096, 0.24472847, 0.09003057]])
"""
# Avoid circular import
inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
new_inputs = [vectorize.get(inp, inp) for inp in inputs]
def transform(var):
if var in inputs:
return new_inputs[inputs.index(var)]
node = var.owner
batched_inputs = [transform(inp) for inp in node.inputs]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return batched_var
# TODO: MergeOptimization or node caching?
return [transform(out) for out in outputs]
......@@ -2,7 +2,8 @@ from itertools import chain
from typing import Optional, Sequence, Tuple
from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone
from pytensor.graph.basic import Constant, Variable, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
......
import re
from functools import singledispatch
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
import numpy as np
......@@ -9,6 +8,7 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
......@@ -72,8 +72,8 @@ def safe_signature(
return f"{inputs_sig}->{outputs_sig}"
@singledispatch
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
@_vectorize_node.register(Op)
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
if hasattr(op, "gufunc_signature"):
signature = op.gufunc_signature
else:
......@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)
class Blockwise(Op):
"""Generalizes a core `Op` to work with batched dimensions.
......@@ -279,42 +273,18 @@ class Blockwise(Op):
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
batch_ndims = self._batch_ndim_from_outputs(outputs)
def transform(var):
# From a graph of ScalarOps, make a graph of Broadcast ops.
if isinstance(var.type, (NullType, DisconnectedType)):
return var
if var in core_inputs:
return inputs[core_inputs.index(var)]
if var in core_outputs:
return outputs[core_outputs.index(var)]
if var in core_ograds:
return ograds[core_ograds.index(var)]
node = var.owner
# The gradient contains a constant, which may be responsible for broadcasting
if node is None:
if batch_ndims:
var = shape_padleft(var, batch_ndims)
return var
batched_inputs = [transform(inp) for inp in node.inputs]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return batched_var
ret = []
for core_igrad, ipt in zip(core_igrads, inputs):
# Undefined gradient
if core_igrad is None:
ret.append(None)
else:
ret.append(transform(core_igrad))
igrads = vectorize(
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
vectorize=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
),
)
return ret
igrads_iter = iter(igrads)
return [
None if core_igrad is None else next(igrads_iter)
for core_igrad in core_igrads
]
def L_op(self, inputs, outs, ograds):
from pytensor.tensor.math import sum as pt_sum
......
......@@ -8,6 +8,7 @@ from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
......@@ -22,7 +23,7 @@ from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
from pytensor.tensor.blockwise import vectorize_not_needed
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
......
......@@ -7,6 +7,7 @@ import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
......@@ -17,7 +18,6 @@ from pytensor.tensor.basic import (
get_vector_length,
infer_static_shape,
)
from pytensor.tensor.blockwise import _vectorize_node
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import (
broadcast_params,
......
from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise, vectorize_node
from pytensor.tensor.blockwise import Blockwise
@node_rewriter([Blockwise])
......
import numpy as np
import pytest
import scipy.special
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
......@@ -223,3 +224,21 @@ class TestGraphReplace:
assert oc[0] is o
with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True)
class TestVectorize:
# TODO: Add tests with multiple outputs, constants, and other singleton types
def test_basic(self):
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
fn = function([new_x], new_y)
test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
np.testing.assert_allclose(
fn(test_new_y),
scipy.special.softmax(test_new_y, axis=-1),
)
......@@ -4,8 +4,8 @@ import pytest
import pytensor.tensor as at
from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad
from pytensor.graph.replace import vectorize_node
from pytensor.raise_op import Assert
from pytensor.tensor.blockwise import vectorize_node
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
......
......@@ -8,8 +8,9 @@ import pytensor
from pytensor import config
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve
......
......@@ -13,11 +13,11 @@ from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second
from pytensor.tensor.blockwise import vectorize_node
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Any, Sum
from pytensor.tensor.math import all as pt_all
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论