Commit 27781606 authored by Ricardo Vieira, committed by Ricardo Vieira

Rename replace/vectorize to replace/vectorize_graph

Parent 902eeb6f
......@@ -9,7 +9,7 @@ from pytensor.graph.basic import (
clone,
ancestors,
)
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
......
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from typing import Optional, Union, cast, overload
......@@ -215,7 +216,7 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
@overload
def vectorize(
def vectorize_graph(
outputs: Variable,
replace: Mapping[Variable, Variable],
) -> Variable:
......@@ -223,14 +224,14 @@ def vectorize(
@overload
def vectorize(
def vectorize_graph(
outputs: Sequence[Variable],
replace: Mapping[Variable, Variable],
) -> Sequence[Variable]:
...
def vectorize(
def vectorize_graph(
outputs: Union[Variable, Sequence[Variable]],
replace: Mapping[Variable, Variable],
) -> Union[Variable, Sequence[Variable]]:
......@@ -309,3 +310,8 @@ def vectorize(
else:
[vect_output] = seq_vect_outputs
return vect_output
def vectorize(*args, **kwargs):
    """Deprecated alias for :func:`vectorize_graph`.

    Emits a ``UserWarning`` about the rename and forwards all positional and
    keyword arguments unchanged to ``vectorize_graph``.

    Returns
    -------
    The value returned by ``vectorize_graph(*args, **kwargs)`` — a single
    ``Variable`` or a sequence of ``Variable``s, matching the input form.
    """
    warnings.warn("vectorize was renamed to vectorize_graph", UserWarning)
    # Fix: the shim must delegate to vectorize_graph (the rename target the
    # warning message announces), not vectorize_node, which operates on a
    # single Apply node and has an incompatible signature/semantics.
    return vectorize_graph(*args, **kwargs)
......@@ -9,7 +9,7 @@ from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node, vectorize
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
......@@ -274,7 +274,7 @@ class Blockwise(Op):
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
igrads = vectorize(
igrads = vectorize_graph(
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
replace=dict(
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
......
......@@ -5,7 +5,7 @@ import scipy.special
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
......@@ -226,7 +226,7 @@ class TestGraphReplace:
oc = graph_replace([o], {fake: x.clone()}, strict=True)
class TestVectorize:
class TestVectorizeGraph:
# TODO: Add tests with multiple outputs, constants, and other singleton types
def test_basic(self):
......@@ -234,10 +234,10 @@ class TestVectorize:
y = pt.exp(x) / pt.sum(pt.exp(x))
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})
[new_y] = vectorize_graph([y], {x: new_x})
# Check we can pass both a sequence or a single variable
alt_new_y = vectorize(y, {x: new_x})
alt_new_y = vectorize_graph(y, {x: new_x})
assert equal_computations([new_y], [alt_new_y])
fn = function([new_x], new_y)
......@@ -253,7 +253,7 @@ class TestVectorize:
y2 = x[-1]
new_x = pt.matrix("new_x")
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
[new_y1, new_y2] = vectorize_graph([y1, y2], {x: new_x})
fn = function([new_x], [new_y1, new_y2])
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
......
Markdown formatting is supported
0%
You are mentioning 0 people in this discussion. Please proceed with care.
Please finish editing this comment first!
Register or sign in to post a comment