提交 d3eda996 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Shape only accepts TensorVariables

Before: shape(NoneType) -> np.shape(None) -> () (numpy tends to do this for things it doesn't understand); shape(TypedList) -> shape(np.asarray(TypedList)) -> either converts or fails. Previously, TypedLists were treated as vectors / 1d, like python lists.
上级 1243d583
......@@ -4270,6 +4270,9 @@ class Choose(Op):
def infer_shape(self, fgraph, node, shapes):
a_shape, choices_shape = shapes
if choices_shape is None:
# choices is a TypedList, not a tensor; no shape to broadcast
return [a_shape]
out_shape = pytensor.tensor.extra_ops.broadcast_shape(
a_shape, choices_shape[1:], arrays_are_shapes=True
)
......@@ -4291,8 +4294,10 @@ class Choose(Op):
# otherwise use as_tensor_variable
if isinstance(choices, tuple | list):
choice = pytensor.typed_list.make_list(choices)
choice_dtype = choice.ttype.dtype
else:
choice = as_tensor_variable(choices)
choice_dtype = choice.dtype
(out_shape,) = self.infer_shape(
None, None, [shape_tuple(a), shape_tuple(choice)]
......@@ -4310,7 +4315,7 @@ class Choose(Op):
else:
static_out_shape += (None,)
o = TensorType(choice.dtype, shape=static_out_shape)
o = TensorType(choice_dtype, shape=static_out_shape)
return Apply(self, [a, choice], [o()])
def perform(self, node, inputs, outputs):
......
......@@ -71,13 +71,8 @@ class Shape(COp):
__props__ = ()
def make_node(self, x):
if not isinstance(x, Variable):
x = ptb.as_tensor_variable(x)
if isinstance(x.type, TensorType):
out_var = TensorType("int64", (x.type.ndim,))()
else:
out_var = pytensor.tensor.type.lvector()
out_var = tensor(dtype="int64", shape=(x.type.ndim,))
return Apply(self, [x], [out_var])
......
......@@ -46,8 +46,6 @@ class _typed_list_py_operators:
return length(self)
ttype = property(lambda self: self.type.ttype)
dtype = property(lambda self: self.type.ttype.dtype)
ndim = property(lambda self: self.type.ttype.ndim + 1)
class TypedListVariable(_typed_list_py_operators, Variable):
......
......@@ -7,13 +7,15 @@ class TypedListType(CType):
Parameters
----------
ttype
Type of pytensor variable this list will contains, can be another list.
Type of pytensor variable this list contains, can be another list.
depth
Optional parameters, any value above 0 will create a nested list of
this depth. (0-based)
"""
dtype = property(lambda self: self.ttype)
def __init__(self, ttype, depth=0):
if depth < 0:
raise ValueError("Please specify a depth superior or equal to 0")
......@@ -137,6 +139,3 @@ class TypedListType(CType):
def c_code_cache_version(self):
return (2,)
dtype = property(lambda self: self.ttype)
ndim = property(lambda self: self.ttype.ndim + 1)
......@@ -3,7 +3,6 @@ import pytest
from pytensor import (
Mode,
Variable,
config,
function,
shared,
......@@ -16,7 +15,6 @@ from pytensor.graph import (
FunctionGraph,
Op,
RewriteDatabaseQuery,
Type,
rewrite_graph,
)
from pytensor.graph.basic import equal_computations
......@@ -31,7 +29,6 @@ from pytensor.tensor import (
lscalar,
lscalars,
matrix,
shape,
specify_shape,
tensor,
tensor3,
......@@ -769,18 +766,6 @@ def test_local_subtensor_shape_constant():
assert isinstance(res, Constant)
assert np.array_equal(res.data, [1, 1])
# A test for a non-`TensorType`
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
x = shape(Variable(MyType(), None, None))[0]
assert not local_subtensor_shape_constant.transform(None, x.owner)
@pytest.mark.parametrize(
"original_fn, supported",
......
......@@ -9,7 +9,6 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.replace import clone_replace, vectorize_graph
from pytensor.graph.type import Type
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, arange, constant, stack
......@@ -26,7 +25,6 @@ from pytensor.tensor.shape import (
specify_broadcastable,
specify_shape,
)
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
TensorType,
dmatrix,
......@@ -61,16 +59,6 @@ def test_shape_basic():
s = shape(lscalar())
assert s.type.shape == (0,)
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType(), None))
assert s.type.shape == (None,)
s = shape(np.array(1))
assert np.array_equal(eval_outputs([s]), [])
......@@ -681,28 +669,20 @@ class TestRopLop(RopLopChecker):
)
@config.change_flags(compute_test_value="raise")
def test_nonstandard_shapes():
def test_shape_rejects_non_tensor_type():
"""Shape raises TypeError for non-TensorType inputs."""
with pytest.raises(TypeError, match="TensorType"):
shape(NoneConst)
unknown_type_var = Variable(MyType2(), None, None)
with pytest.raises(TypeError, match="TensorType"):
shape(unknown_type_var)
a = tensor3(config.floatX)
a.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
b = tensor3(config.floatX)
b.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
tl = make_list([a, b])
tl_shape = shape(tl)
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4))
# Test specific dim
tl_shape_i = shape(tl)[0]
assert isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
tl_shape_i = Shape_i(0)(tl)
assert not isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
none_shape = shape(NoneConst)
assert np.array_equal(none_shape.get_test_value(), [])
with pytest.raises(TypeError, match="TensorType"):
shape(tl)
def test_shape_i_basics():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论