提交 47efd4b1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow single variable output in vectorize

Also: * rename the `vectorize` kwarg to `replace` * add test for multiple outputs
上级 30e08e23
...@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply: ...@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return _vectorize_node(op, node, *batched_inputs) return _vectorize_node(op, node, *batched_inputs)
@overload
def vectorize(
outputs: Variable,
replace: Mapping[Variable, Variable],
) -> Variable:
...
@overload
def vectorize( def vectorize(
outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable] outputs: Sequence[Variable],
replace: Mapping[Variable, Variable],
) -> Sequence[Variable]: ) -> Sequence[Variable]:
...
def vectorize(
outputs: Union[Variable, Sequence[Variable]],
replace: Mapping[Variable, Variable],
) -> Union[Variable, Sequence[Variable]]:
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version. """Vectorize outputs graph given mapping from old variables to expanded counterparts version.
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`. Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
...@@ -235,20 +252,44 @@ def vectorize( ...@@ -235,20 +252,44 @@ def vectorize(
# Vectorized graph # Vectorized graph
new_x = pt.matrix("new_x") new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x}) new_y = vectorize(y, replace={x: new_x})
fn = pytensor.function([new_x], new_y) fn = pytensor.function([new_x], new_y)
fn([[0, 1, 2], [2, 1, 0]]) fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096], # array([[0.09003057, 0.24472847, 0.66524096],
# [0.66524096, 0.24472847, 0.09003057]]) # [0.66524096, 0.24472847, 0.09003057]])
.. code-block:: python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize
# Original graph
x = pt.vector("x")
y1 = x[0]
y2 = x[-1]
# Vectorized graph
new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
fn = pytensor.function([new_x], [new_y1, new_y2])
fn([[-10, 0, 10], [-11, 0, 11]])
# [array([-10., -11.]), array([10., 11.])]
""" """
# Avoid circular import if isinstance(outputs, Sequence):
seq_outputs = outputs
else:
seq_outputs = [outputs]
inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys()) inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [vectorize.get(inp, inp) for inp in inputs] new_inputs = [replace.get(inp, inp) for inp in inputs]
def transform(var): def transform(var: Variable) -> Variable:
if var in inputs: if var in inputs:
return new_inputs[inputs.index(var)] return new_inputs[inputs.index(var)]
...@@ -257,7 +298,13 @@ def vectorize( ...@@ -257,7 +298,13 @@ def vectorize(
batched_node = vectorize_node(node, *batched_inputs) batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)] batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return batched_var return cast(Variable, batched_var)
# TODO: MergeOptimization or node caching? # TODO: MergeOptimization or node caching?
return [transform(out) for out in outputs] seq_vect_outputs = [transform(out) for out in seq_outputs]
if isinstance(outputs, Sequence):
return seq_vect_outputs
else:
[vect_output] = seq_vect_outputs
return vect_output
...@@ -275,7 +275,7 @@ class Blockwise(Op): ...@@ -275,7 +275,7 @@ class Blockwise(Op):
igrads = vectorize( igrads = vectorize(
[core_igrad for core_igrad in core_igrads if core_igrad is not None], [core_igrad for core_igrad in core_igrads if core_igrad is not None],
vectorize=dict( replace=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
), ),
) )
......
...@@ -4,7 +4,7 @@ import scipy.special ...@@ -4,7 +4,7 @@ import scipy.special
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, shared from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -236,9 +236,27 @@ class TestVectorize: ...@@ -236,9 +236,27 @@ class TestVectorize:
new_x = pt.matrix("new_x") new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x}) [new_y] = vectorize([y], {x: new_x})
# Check we can pass both a sequence or a single variable
alt_new_y = vectorize(y, {x: new_x})
assert equal_computations([new_y], [alt_new_y])
fn = function([new_x], new_y) fn = function([new_x], new_y)
test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX) test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
np.testing.assert_allclose( np.testing.assert_allclose(
fn(test_new_y), fn(test_new_y),
scipy.special.softmax(test_new_y, axis=-1), scipy.special.softmax(test_new_y, axis=-1),
) )
def test_multiple_outputs(self):
    """Vectorizing a sequence of outputs returns a matching sequence of batched outputs."""
    # Original graph: two scalar outputs taken from a single vector input.
    x = pt.vector("x")
    y1 = x[0]
    y2 = x[-1]
    # Vectorized graph: a new leading (batch) dimension is added on the left,
    # so each row of `new_x` plays the role of one `x`.
    new_x = pt.matrix("new_x")
    [new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
    fn = function([new_x], [new_y1, new_y2])
    # 3 batch rows: [0,1,2], [3,4,5], [6,7,8]
    new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
    new_y1_res, new_y2_res = fn(new_x_test)
    # Per-row first elements (x[0]) and last elements (x[-1]).
    np.testing.assert_allclose(new_y1_res, [0, 3, 6])
    np.testing.assert_allclose(new_y2_res, [2, 5, 8])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论