提交 27781606 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename replace/vectorize to replace/vectorize_graph

上级 902eeb6f
...@@ -9,7 +9,7 @@ from pytensor.graph.basic import ( ...@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone, clone,
ancestors, ancestors,
) )
from pytensor.graph.replace import clone_replace, graph_replace, vectorize from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
......
import warnings
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch from functools import partial, singledispatch
from typing import Optional, Union, cast, overload from typing import Optional, Union, cast, overload
...@@ -215,7 +216,7 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply: ...@@ -215,7 +216,7 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
@overload @overload
def vectorize( def vectorize_graph(
outputs: Variable, outputs: Variable,
replace: Mapping[Variable, Variable], replace: Mapping[Variable, Variable],
) -> Variable: ) -> Variable:
...@@ -223,14 +224,14 @@ def vectorize( ...@@ -223,14 +224,14 @@ def vectorize(
@overload @overload
def vectorize( def vectorize_graph(
outputs: Sequence[Variable], outputs: Sequence[Variable],
replace: Mapping[Variable, Variable], replace: Mapping[Variable, Variable],
) -> Sequence[Variable]: ) -> Sequence[Variable]:
... ...
def vectorize( def vectorize_graph(
outputs: Union[Variable, Sequence[Variable]], outputs: Union[Variable, Sequence[Variable]],
replace: Mapping[Variable, Variable], replace: Mapping[Variable, Variable],
) -> Union[Variable, Sequence[Variable]]: ) -> Union[Variable, Sequence[Variable]]:
...@@ -309,3 +310,8 @@ def vectorize( ...@@ -309,3 +310,8 @@ def vectorize(
else: else:
[vect_output] = seq_vect_outputs [vect_output] = seq_vect_outputs
return vect_output return vect_output
def vectorize(*args, **kwargs):
warnings.warn("vectorize was renamed to vectorize_graph", UserWarning)
return vectorize_node(*args, **kwargs)
...@@ -9,7 +9,7 @@ from pytensor.gradient import DisconnectedType ...@@ -9,7 +9,7 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
...@@ -274,7 +274,7 @@ class Blockwise(Op): ...@@ -274,7 +274,7 @@ class Blockwise(Op):
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
igrads = vectorize( igrads = vectorize_graph(
[core_igrad for core_igrad in core_igrads if core_igrad is not None], [core_igrad for core_igrad in core_igrads if core_igrad is not None],
replace=dict( replace=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
......
...@@ -5,7 +5,7 @@ import scipy.special ...@@ -5,7 +5,7 @@ import scipy.special
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, shared from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations, graph_inputs from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.tensor import dvector, fvector, vector from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable from tests.graph.utils import MyOp, MyVariable
...@@ -226,7 +226,7 @@ class TestGraphReplace: ...@@ -226,7 +226,7 @@ class TestGraphReplace:
oc = graph_replace([o], {fake: x.clone()}, strict=True) oc = graph_replace([o], {fake: x.clone()}, strict=True)
class TestVectorize: class TestVectorizeGraph:
# TODO: Add tests with multiple outputs, constants, and other singleton types # TODO: Add tests with multiple outputs, constants, and other singleton types
def test_basic(self): def test_basic(self):
...@@ -234,10 +234,10 @@ class TestVectorize: ...@@ -234,10 +234,10 @@ class TestVectorize:
y = pt.exp(x) / pt.sum(pt.exp(x)) y = pt.exp(x) / pt.sum(pt.exp(x))
new_x = pt.matrix("new_x") new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x}) [new_y] = vectorize_graph([y], {x: new_x})
# Check we can pass both a sequence or a single variable # Check we can pass both a sequence or a single variable
alt_new_y = vectorize(y, {x: new_x}) alt_new_y = vectorize_graph(y, {x: new_x})
assert equal_computations([new_y], [alt_new_y]) assert equal_computations([new_y], [alt_new_y])
fn = function([new_x], new_y) fn = function([new_x], new_y)
...@@ -253,7 +253,7 @@ class TestVectorize: ...@@ -253,7 +253,7 @@ class TestVectorize:
y2 = x[-1] y2 = x[-1]
new_x = pt.matrix("new_x") new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x}) [new_y1, new_y2] = vectorize_graph([y1, y2], {x: new_x})
fn = function([new_x], [new_y1, new_y2]) fn = function([new_x], [new_y1, new_y2])
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX) new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论