提交 47efd4b1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow single variable output in vectorize

Also: * rename the `vectorize` kwarg to `replace` * add a test for multiple outputs
上级 30e08e23
......@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return _vectorize_node(op, node, *batched_inputs)
# Typing overload: a single `Variable` passed as `outputs` yields a single `Variable`.
@overload
def vectorize(
outputs: Variable,
replace: Mapping[Variable, Variable],
) -> Variable:
...
@overload
def vectorize(
outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
outputs: Sequence[Variable],
replace: Mapping[Variable, Variable],
) -> Sequence[Variable]:
...
def vectorize(
outputs: Union[Variable, Sequence[Variable]],
replace: Mapping[Variable, Variable],
) -> Union[Variable, Sequence[Variable]]:
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
......@@ -235,20 +252,44 @@ def vectorize(
# Vectorized graph
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
new_y = vectorize(y, replace={x: new_x})
fn = pytensor.function([new_x], new_y)
fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096],
# [0.66524096, 0.24472847, 0.09003057]])
.. code-block:: python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize
# Original graph
x = pt.vector("x")
y1 = x[0]
y2 = x[-1]
# Vectorized graph
new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
fn = pytensor.function([new_x], [new_y1, new_y2])
fn([[-10, 0, 10], [-11, 0, 11]])
# [array([-10., -11.]), array([10., 11.])]
"""
# Avoid circular import
if isinstance(outputs, Sequence):
seq_outputs = outputs
else:
seq_outputs = [outputs]
inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
new_inputs = [vectorize.get(inp, inp) for inp in inputs]
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs]
def transform(var):
def transform(var: Variable) -> Variable:
if var in inputs:
return new_inputs[inputs.index(var)]
......@@ -257,7 +298,13 @@ def vectorize(
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return batched_var
return cast(Variable, batched_var)
# TODO: MergeOptimization or node caching?
return [transform(out) for out in outputs]
seq_vect_outputs = [transform(out) for out in seq_outputs]
if isinstance(outputs, Sequence):
return seq_vect_outputs
else:
[vect_output] = seq_vect_outputs
return vect_output
......@@ -275,7 +275,7 @@ class Blockwise(Op):
igrads = vectorize(
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
vectorize=dict(
replace=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
),
)
......
......@@ -4,7 +4,7 @@ import scipy.special
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import graph_inputs
from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
......@@ -236,9 +236,27 @@ class TestVectorize:
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
# Check we can pass either a sequence or a single variable
alt_new_y = vectorize(y, {x: new_x})
assert equal_computations([new_y], [alt_new_y])
fn = function([new_x], new_y)
test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
np.testing.assert_allclose(
fn(test_new_y),
scipy.special.softmax(test_new_y, axis=-1),
)
def test_multiple_outputs(self):
# Vectorizing a list of two outputs in one call must give back two batched outputs.
# Core graph: the first and the last entry of a vector.
x = pt.vector("x")
y1 = x[0]
y2 = x[-1]
# Swap the vector input for a matrix and vectorize both outputs together.
new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
fn = function([new_x], [new_y1, new_y2])
# Test matrix rows are [0, 1, 2], [3, 4, 5] and [6, 7, 8].
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
new_y1_res, new_y2_res = fn(new_x_test)
# Expected results: the first and the last entry of every row.
np.testing.assert_allclose(new_y1_res, [0, 3, 6])
np.testing.assert_allclose(new_y2_res, [2, 5, 8])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论