提交 1c2bc8fe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove rarely used shape_i helpers

上级 d9b3924f
......@@ -42,7 +42,6 @@ from pytensor.tensor.shape import (
Shape_i,
SpecifyShape,
Unbroadcast,
shape_i,
specify_shape,
unbroadcast,
)
......@@ -1060,7 +1059,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
# Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape):
if NoneConst.equals(sh):
shape[i] = shape_i(x, i, fgraph)
shape[i] = x.shape[i]
return [stack(shape).astype(np.int64)]
......
......@@ -363,16 +363,6 @@ def shape_i(var, i, fgraph=None):
return shape(var)[i]
def shape_i_op(i):
    """Return a memoized ``Shape_i`` op for axis ``i``.

    Repeated calls with the same ``i`` hand back the identical
    ``Shape_i`` instance, so graph comparisons can rely on op identity.
    """
    # EAFP: the cache hit is the common case; construct on miss only.
    try:
        return shape_i_op.cache[i]
    except KeyError:
        return shape_i_op.cache.setdefault(i, Shape_i(i))


# Per-axis cache shared across all callers of shape_i_op.
shape_i_op.cache = {}  # type: ignore
def register_shape_i_c_code(typ, code, check_input, version=()):
"""
Tell Shape_i how to generate C code for a PyTensor Type.
......
......@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
bscalar,
......@@ -2705,10 +2705,9 @@ class AdvancedSubtensor(Op):
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:]):
# Mixed bool indexes are converted to nonzero entries
shape0_op = Shape_i(0)
if is_bool_index(idx):
index_shapes.extend(
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
)
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
elif isinstance(getattr(idx, "type", None), SliceType):
......
......@@ -8,7 +8,6 @@ from pytensor import Mode, function, grad
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace, vectorize_node
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
......@@ -16,7 +15,6 @@ from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, constant, stack
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
Reshape,
Shape,
......@@ -26,7 +24,6 @@ from pytensor.tensor.shape import (
_specify_shape,
reshape,
shape,
shape_i,
shape_tuple,
specify_broadcastable,
specify_shape,
......@@ -633,13 +630,12 @@ def test_nonstandard_shapes():
tl_shape = shape(tl)
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
# There's no `FunctionGraph`, so it should return a `Subtensor`
tl_shape_i = shape_i(tl, 0)
# Test specific dim
tl_shape_i = shape(tl)[0]
assert isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
tl_fg = FunctionGraph([a, b], [tl], features=[ShapeFeature()])
tl_shape_i = shape_i(tl, 0, fgraph=tl_fg)
tl_shape_i = Shape_i(0)(tl)
assert not isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论