Commit 1c2bc8fe, authored by Ricardo Vieira, committed by Ricardo Vieira

Remove rarely used shape_i helpers

Parent commit: d9b3924f
...@@ -42,7 +42,6 @@ from pytensor.tensor.shape import ( ...@@ -42,7 +42,6 @@ from pytensor.tensor.shape import (
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast, Unbroadcast,
shape_i,
specify_shape, specify_shape,
unbroadcast, unbroadcast,
) )
...@@ -1060,7 +1059,7 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -1060,7 +1059,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
# Replace `NoneConst` by `shape_i` # Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape): for i, sh in enumerate(shape):
if NoneConst.equals(sh): if NoneConst.equals(sh):
shape[i] = shape_i(x, i, fgraph) shape[i] = x.shape[i]
return [stack(shape).astype(np.int64)] return [stack(shape).astype(np.int64)]
......
...@@ -363,16 +363,6 @@ def shape_i(var, i, fgraph=None): ...@@ -363,16 +363,6 @@ def shape_i(var, i, fgraph=None):
return shape(var)[i] return shape(var)[i]
def shape_i_op(i):
key = i
if key not in shape_i_op.cache:
shape_i_op.cache[key] = Shape_i(i)
return shape_i_op.cache[key]
shape_i_op.cache = {} # type: ignore
def register_shape_i_c_code(typ, code, check_input, version=()): def register_shape_i_c_code(typ, code, check_input, version=()):
""" """
Tell Shape_i how to generate C code for an PyTensor Type. Tell Shape_i how to generate C code for an PyTensor Type.
......
...@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import vectorize_node_fallback ...@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip from pytensor.tensor.math import clip
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
bscalar, bscalar,
...@@ -2705,10 +2705,9 @@ class AdvancedSubtensor(Op): ...@@ -2705,10 +2705,9 @@ class AdvancedSubtensor(Op):
index_shapes = [] index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]): for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries # Mixed bool indexes are converted to nonzero entries
shape0_op = Shape_i(0)
if is_bool_index(idx): if is_bool_index(idx):
index_shapes.extend( index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
)
# The `ishapes` entries for `SliceType`s will be None, and # The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices. # we need to give `indexed_result_shape` the actual slices.
elif isinstance(getattr(idx, "type", None), SliceType): elif isinstance(getattr(idx, "type", None), SliceType):
......
...@@ -8,7 +8,6 @@ from pytensor import Mode, function, grad ...@@ -8,7 +8,6 @@ from pytensor import Mode, function, grad
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace, vectorize_node from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
...@@ -16,7 +15,6 @@ from pytensor.scalar.basic import ScalarConstant ...@@ -16,7 +15,6 @@ from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant, stack from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Reshape, Reshape,
Shape, Shape,
...@@ -26,7 +24,6 @@ from pytensor.tensor.shape import ( ...@@ -26,7 +24,6 @@ from pytensor.tensor.shape import (
_specify_shape, _specify_shape,
reshape, reshape,
shape, shape,
shape_i,
shape_tuple, shape_tuple,
specify_broadcastable, specify_broadcastable,
specify_shape, specify_shape,
...@@ -633,13 +630,12 @@ def test_nonstandard_shapes(): ...@@ -633,13 +630,12 @@ def test_nonstandard_shapes():
tl_shape = shape(tl) tl_shape = shape(tl)
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4)) assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
# There's no `FunctionGraph`, so it should return a `Subtensor` # Test specific dim
tl_shape_i = shape_i(tl, 0) tl_shape_i = shape(tl)[0]
assert isinstance(tl_shape_i.owner.op, Subtensor) assert isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2 assert tl_shape_i.get_test_value() == 2
tl_fg = FunctionGraph([a, b], [tl], features=[ShapeFeature()]) tl_shape_i = Shape_i(0)(tl)
tl_shape_i = shape_i(tl, 0, fgraph=tl_fg)
assert not isinstance(tl_shape_i.owner.op, Subtensor) assert not isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2 assert tl_shape_i.get_test_value() == 2
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment