提交 d3eda996 authored 作者: ricardoV94 提交者: Ricardo Vieira

Shape only accepts TensorVariables

Before: shape(NoneType) -> np.shape(None) -> () (numpy tends to do this for things it doesn't understand); shape(TypedList) -> shape(np.asarray(TypedList)) -> either converts or fails. Now TypedLists are considered vectors / 1d, like Python lists.
上级 1243d583
...@@ -4270,6 +4270,9 @@ class Choose(Op): ...@@ -4270,6 +4270,9 @@ class Choose(Op):
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
a_shape, choices_shape = shapes a_shape, choices_shape = shapes
if choices_shape is None:
# choices is a TypedList, not a tensor; no shape to broadcast
return [a_shape]
out_shape = pytensor.tensor.extra_ops.broadcast_shape( out_shape = pytensor.tensor.extra_ops.broadcast_shape(
a_shape, choices_shape[1:], arrays_are_shapes=True a_shape, choices_shape[1:], arrays_are_shapes=True
) )
...@@ -4291,8 +4294,10 @@ class Choose(Op): ...@@ -4291,8 +4294,10 @@ class Choose(Op):
# otherwise use as_tensor_variable # otherwise use as_tensor_variable
if isinstance(choices, tuple | list): if isinstance(choices, tuple | list):
choice = pytensor.typed_list.make_list(choices) choice = pytensor.typed_list.make_list(choices)
choice_dtype = choice.ttype.dtype
else: else:
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
choice_dtype = choice.dtype
(out_shape,) = self.infer_shape( (out_shape,) = self.infer_shape(
None, None, [shape_tuple(a), shape_tuple(choice)] None, None, [shape_tuple(a), shape_tuple(choice)]
...@@ -4310,7 +4315,7 @@ class Choose(Op): ...@@ -4310,7 +4315,7 @@ class Choose(Op):
else: else:
static_out_shape += (None,) static_out_shape += (None,)
o = TensorType(choice.dtype, shape=static_out_shape) o = TensorType(choice_dtype, shape=static_out_shape)
return Apply(self, [a, choice], [o()]) return Apply(self, [a, choice], [o()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
......
...@@ -71,13 +71,8 @@ class Shape(COp): ...@@ -71,13 +71,8 @@ class Shape(COp):
__props__ = () __props__ = ()
def make_node(self, x): def make_node(self, x):
if not isinstance(x, Variable): x = ptb.as_tensor_variable(x)
x = ptb.as_tensor_variable(x) out_var = tensor(dtype="int64", shape=(x.type.ndim,))
if isinstance(x.type, TensorType):
out_var = TensorType("int64", (x.type.ndim,))()
else:
out_var = pytensor.tensor.type.lvector()
return Apply(self, [x], [out_var]) return Apply(self, [x], [out_var])
......
...@@ -46,8 +46,6 @@ class _typed_list_py_operators: ...@@ -46,8 +46,6 @@ class _typed_list_py_operators:
return length(self) return length(self)
ttype = property(lambda self: self.type.ttype) ttype = property(lambda self: self.type.ttype)
dtype = property(lambda self: self.type.ttype.dtype)
ndim = property(lambda self: self.type.ttype.ndim + 1)
class TypedListVariable(_typed_list_py_operators, Variable): class TypedListVariable(_typed_list_py_operators, Variable):
......
...@@ -7,13 +7,15 @@ class TypedListType(CType): ...@@ -7,13 +7,15 @@ class TypedListType(CType):
Parameters Parameters
---------- ----------
ttype ttype
Type of pytensor variable this list will contains, can be another list. Type of pytensor variable this list contains, can be another list.
depth depth
Optional parameters, any value above 0 will create a nested list of Optional parameters, any value above 0 will create a nested list of
this depth. (0-based) this depth. (0-based)
""" """
dtype = property(lambda self: self.ttype)
def __init__(self, ttype, depth=0): def __init__(self, ttype, depth=0):
if depth < 0: if depth < 0:
raise ValueError("Please specify a depth superior or equal to 0") raise ValueError("Please specify a depth superior or equal to 0")
...@@ -137,6 +139,3 @@ class TypedListType(CType): ...@@ -137,6 +139,3 @@ class TypedListType(CType):
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
dtype = property(lambda self: self.ttype)
ndim = property(lambda self: self.ttype.ndim + 1)
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
from pytensor import ( from pytensor import (
Mode, Mode,
Variable,
config, config,
function, function,
shared, shared,
...@@ -16,7 +15,6 @@ from pytensor.graph import ( ...@@ -16,7 +15,6 @@ from pytensor.graph import (
FunctionGraph, FunctionGraph,
Op, Op,
RewriteDatabaseQuery, RewriteDatabaseQuery,
Type,
rewrite_graph, rewrite_graph,
) )
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
...@@ -31,7 +29,6 @@ from pytensor.tensor import ( ...@@ -31,7 +29,6 @@ from pytensor.tensor import (
lscalar, lscalar,
lscalars, lscalars,
matrix, matrix,
shape,
specify_shape, specify_shape,
tensor, tensor,
tensor3, tensor3,
...@@ -769,18 +766,6 @@ def test_local_subtensor_shape_constant(): ...@@ -769,18 +766,6 @@ def test_local_subtensor_shape_constant():
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert np.array_equal(res.data, [1, 1]) assert np.array_equal(res.data, [1, 1])
# A test for a non-`TensorType`
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
x = shape(Variable(MyType(), None, None))[0]
assert not local_subtensor_shape_constant.transform(None, x.owner)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"original_fn, supported", "original_fn, supported",
......
...@@ -9,7 +9,6 @@ from pytensor.compile.ops import DeepCopyOp ...@@ -9,7 +9,6 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Variable, equal_computations from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.replace import clone_replace, vectorize_graph from pytensor.graph.replace import clone_replace, vectorize_graph
from pytensor.graph.type import Type
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row from pytensor.tensor import as_tensor_variable, broadcast_to, get_vector_length, row
from pytensor.tensor.basic import MakeVector, arange, constant, stack from pytensor.tensor.basic import MakeVector, arange, constant, stack
...@@ -26,7 +25,6 @@ from pytensor.tensor.shape import ( ...@@ -26,7 +25,6 @@ from pytensor.tensor.shape import (
specify_broadcastable, specify_broadcastable,
specify_shape, specify_shape,
) )
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
dmatrix, dmatrix,
...@@ -61,16 +59,6 @@ def test_shape_basic(): ...@@ -61,16 +59,6 @@ def test_shape_basic():
s = shape(lscalar()) s = shape(lscalar())
assert s.type.shape == (0,) assert s.type.shape == (0,)
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType(), None))
assert s.type.shape == (None,)
s = shape(np.array(1)) s = shape(np.array(1))
assert np.array_equal(eval_outputs([s]), []) assert np.array_equal(eval_outputs([s]), [])
...@@ -681,28 +669,20 @@ class TestRopLop(RopLopChecker): ...@@ -681,28 +669,20 @@ class TestRopLop(RopLopChecker):
) )
@config.change_flags(compute_test_value="raise") def test_shape_rejects_non_tensor_type():
def test_nonstandard_shapes(): """Shape raises TypeError for non-TensorType inputs."""
with pytest.raises(TypeError, match="TensorType"):
shape(NoneConst)
unknown_type_var = Variable(MyType2(), None, None)
with pytest.raises(TypeError, match="TensorType"):
shape(unknown_type_var)
a = tensor3(config.floatX) a = tensor3(config.floatX)
a.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
b = tensor3(config.floatX) b = tensor3(config.floatX)
b.tag.test_value = np.random.random((2, 3, 4)).astype(config.floatX)
tl = make_list([a, b]) tl = make_list([a, b])
tl_shape = shape(tl) with pytest.raises(TypeError, match="TensorType"):
assert np.array_equal(tl_shape.get_test_value(), (2, 2, 3, 4)) shape(tl)
# Test specific dim
tl_shape_i = shape(tl)[0]
assert isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
tl_shape_i = Shape_i(0)(tl)
assert not isinstance(tl_shape_i.owner.op, Subtensor)
assert tl_shape_i.get_test_value() == 2
none_shape = shape(NoneConst)
assert np.array_equal(none_shape.get_test_value(), [])
def test_shape_i_basics(): def test_shape_i_basics():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论