Unverified commit 5c63ee70, authored by Ricardo Vieira, committed by GitHub

Allow passing static shape to tensor creation helpers (#118)

* Allow passing static shape to tensor creation helpers
* Also default dtype to "floatX" when using `tensor`
* Make tensor API similar to that of other variable constructors
* Name is now the only optional non-keyword argument for all constructors
Parent 43d91d0a
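For reference, a minimal sketch of the resulting API, based on the tests added in this diff (assumes `config.floatX` is left at its default; the error message is quoted from the new tests):

```python
from pytensor.tensor.type import tensor

# dtype now defaults to config.floatX when only a shape is given
x = tensor(shape=(5, None))

# name is the only optional positional argument; dtype and shape are keywords
y = tensor("y", dtype="int64", shape=(5, None))
assert y.name == "y" and y.type.shape == (5, None)

# a dtype-like positional argument is rejected, catching the old call style
try:
    tensor("float64", shape=(None,))
except ValueError:
    pass  # "This name looks like a dtype, which you should pass as a keyword argument only"
```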
@@ -3451,7 +3451,12 @@ class StructuredDot(Op):
         return Apply(
             self,
             [a, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )
 
     def perform(self, node, inputs, outputs):
@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp):
     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(g_ab.dtype, shape=(None,))]
+            self,
+            [a_indices, a_indptr, b, g_ab],
+            [tensor(dtype=g_ab.dtype, shape=(None,))],
         )
 
     def perform(self, node, inputs, outputs):
@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp):
     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(b.dtype, shape=(None,))]
+            self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))]
         )
 
     def perform(self, node, inputs, outputs):
......
@@ -270,7 +270,11 @@ class StructuredDotCSC(COp):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, a_nrows, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r
@@ -465,7 +469,12 @@ class StructuredDotCSR(COp):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, b],
-            [tensor(self.dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=self.dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )
         return r
@@ -705,7 +714,11 @@ class UsmmCscDense(_NoPythonCOp):
         r = Apply(
             self,
             [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
-            [tensor(dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r
@@ -1142,7 +1155,9 @@ class MulSDCSC(_NoPythonCOp):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )
 
     def c_code_cache_version(self):
@@ -1280,7 +1295,9 @@ class MulSDCSR(_NoPythonCOp):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )
 
     def c_code_cache_version(self):
@@ -1470,7 +1487,9 @@ class MulSVCSR(_NoPythonCOp):
         """
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )
 
     def c_code_cache_version(self):
@@ -1642,7 +1661,9 @@ class StructuredAddSVCSR(_NoPythonCOp):
         assert a_indptr.type.ndim == 1
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )
 
     def c_code_cache_version(self):
......
@@ -2882,7 +2882,7 @@ class ARange(Op):
         assert step.ndim == 0
 
         inputs = [start, stop, step]
-        outputs = [tensor(self.dtype, shape=(None,))]
+        outputs = [tensor(dtype=self.dtype, shape=(None,))]
 
         return Apply(self, inputs, outputs)
......
@@ -1680,7 +1680,7 @@ class Dot22(GemmRelated):
             raise TypeError(y)
         if y.type.dtype != x.type.dtype:
             raise TypeError("dtype mismatch to Dot22")
-        outputs = [tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
+        outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)
 
     def perform(self, node, inp, out):
@@ -1985,7 +1985,7 @@ class Dot22Scalar(GemmRelated):
             raise TypeError("Dot22Scalar requires float or complex args", a.dtype)
 
         sz = (x.type.shape[0], y.type.shape[1])
-        outputs = [tensor(x.type.dtype, shape=sz)]
+        outputs = [tensor(dtype=x.type.dtype, shape=sz)]
         return Apply(self, [x, y, a], outputs)
 
     def perform(self, node, inp, out):
@@ -2221,7 +2221,7 @@ class BatchedDot(COp):
             + inputs[1].type.shape[2:]
         )
         out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype, shape=out_shape)])
+        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
 
     def perform(self, node, inp, out):
         x, y = inp
......
@@ -36,7 +36,7 @@ class LoadFromDisk(Op):
     def make_node(self, path):
         if isinstance(path, str):
             path = Constant(Generic(), path)
-        return Apply(self, [path], [tensor(self.dtype, shape=self.shape)])
+        return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])
 
     def perform(self, node, inp, out):
         path = inp[0]
@@ -135,7 +135,7 @@ class MPIRecv(Op):
             [],
             [
                 Variable(Generic(), None),
-                tensor(self.dtype, shape=self.static_shape),
+                tensor(dtype=self.dtype, shape=self.static_shape),
             ],
         )
@@ -180,7 +180,7 @@ class MPIRecvWait(Op):
         return Apply(
             self,
             [request, data],
-            [tensor(data.dtype, shape=data.type.shape)],
+            [tensor(dtype=data.dtype, shape=data.type.shape)],
         )
 
     def perform(self, node, inp, out):
......
@@ -152,8 +152,8 @@ class MaxAndArgmax(COp):
             if i not in all_axes
         )
         outputs = [
-            tensor(x.type.dtype, shape=out_shape, name="max"),
-            tensor("int64", shape=out_shape, name="argmax"),
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+            tensor(dtype="int64", shape=out_shape, name="argmax"),
         ]
         return Apply(self, inputs, outputs)
@@ -370,7 +370,7 @@ class Argmax(COp):
         # We keep the original broadcastable flags for dimensions on which
         # we do not perform the argmax.
         out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
-        outputs = [tensor("int64", shape=out_shape, name="argmax")]
+        outputs = [tensor(dtype="int64", shape=out_shape, name="argmax")]
         return Apply(self, inputs, outputs)
 
     def prepare_node(self, node, storage_map, compute_map, impl):
@@ -1922,7 +1922,7 @@ class Dot(Op):
             sz = sx[:-1]
 
         i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
+        outputs = [tensor(dtype=aes.upcast(*i_dtypes), shape=sz)]
         return Apply(self, inputs, outputs)
 
     def perform(self, node, inp, out):
......
@@ -641,7 +641,7 @@ class Reshape(COp):
             except NotScalarConstantError:
                 pass
 
-        return Apply(self, [x, shp], [tensor(x.type.dtype, shape=out_shape)])
+        return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
 
     def perform(self, node, inp, out_, params):
         x, shp = inp
......
Diff collapsed.
@@ -556,13 +556,13 @@ def test_get_var_by_name():
 def test_clone_new_inputs():
     """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
 
-    x = at.tensor(np.float64, shape=(None,))
-    y = at.tensor(np.float64, shape=(1,))
+    x = at.tensor(dtype=np.float64, shape=(None,))
+    y = at.tensor(dtype=np.float64, shape=(1,))
 
     z = at.add(x, y)
     assert z.type.shape == (None,)
 
-    x_new = at.tensor(np.float64, shape=(1,))
+    x_new = at.tensor(dtype=np.float64, shape=(1,))
 
     # The output nodes should be reconstructed, because the input types' static
     # shape information increased in specificity
......
@@ -146,7 +146,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1, None), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
                 np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
             ),
             ("x", 2, "x", 0, "x"),
@@ -155,21 +155,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [0], 'augment': []}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(1, 1, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
                 np.array([[[1.0]]], dtype=config.floatX),
             ),
             (),
......
@@ -270,7 +270,7 @@ rng = np.random.default_rng(42849)
                 np.array([[1, 2], [3, 4]], dtype=np.float64),
             ),
             set_test_value(
-                at.tensor("float64", shape=(1, None, None)),
+                at.tensor(dtype="float64", shape=(1, None, None)),
                 np.eye(2)[None, ...],
             ),
         ],
......
@@ -582,7 +582,7 @@ def test_debugprint_mitmot():
 def test_debugprint_compiled_fn():
-    M = at.tensor(np.float64, shape=(20000, 2, 2))
+    M = at.tensor(dtype=np.float64, shape=(20000, 2, 2))
     one = at.as_tensor(1, dtype=np.int64)
     zero = at.as_tensor(0, dtype=np.int64)
......
@@ -607,7 +607,7 @@ def test_mvnormal_ShapeFeature():
     assert M_at in graph_inputs([s2])
 
     # Test broadcasted shapes
-    mean = tensor(config.floatX, shape=(1, None))
+    mean = tensor(dtype=config.floatX, shape=(1, None))
     mean.tag.test_value = np.array([[0, 1, 2]], dtype=config.floatX)
 
     test_covar = np.diag(np.array([1, 10, 100], dtype=config.floatX))
......
@@ -125,9 +125,9 @@ def test_RandomVariable_basics():
 def test_RandomVariable_bcast():
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
 
-    mu = tensor(config.floatX, shape=(1, None, None))
+    mu = tensor(dtype=config.floatX, shape=(1, None, None))
     mu.tag.test_value = np.zeros((1, 2, 3)).astype(config.floatX)
-    sd = tensor(config.floatX, shape=(None, None))
+    sd = tensor(dtype=config.floatX, shape=(None, None))
     sd.tag.test_value = np.ones((2, 3)).astype(config.floatX)
 
     s1 = iscalar()
@@ -163,10 +163,10 @@ def test_RandomVariable_bcast_specify_shape():
     s3 = Assert("testing")(s3, eq(s1, 1))
     size = specify_shape(at.as_tensor([s1, s3, s2, s2, s1]), (5,))
 
-    mu = tensor(config.floatX, shape=(None, None, 1))
+    mu = tensor(dtype=config.floatX, shape=(None, None, 1))
     mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)
-    std = tensor(config.floatX, shape=(None, 1, 1))
+    std = tensor(dtype=config.floatX, shape=(None, 1, 1))
     std.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX)
 
     res = rv(mu, std, size=size)
......
@@ -69,7 +69,7 @@ def test_broadcast_params():
     # Try it in PyTensor
     with config.change_flags(compute_test_value="raise"):
-        mean = tensor(config.floatX, shape=(None, 1))
+        mean = tensor(dtype=config.floatX, shape=(None, 1))
         mean.tag.test_value = np.array([[0], [10], [100]], dtype=config.floatX)
         cov = matrix()
         cov.tag.test_value = np.diag(np.array([1e-6], dtype=config.floatX))
......
@@ -559,8 +559,8 @@ class TestUnbroadcast:
         self.mode = get_default_mode().including("canonicalize")
 
     def test_local_useless_unbroadcast(self):
-        x1 = tensor("float64", shape=(1, 2))
-        x2 = tensor("float64", shape=(2, 1))
+        x1 = tensor(dtype="float64", shape=(1, 2))
+        x2 = tensor(dtype="float64", shape=(2, 1))
         unbroadcast_op = Unbroadcast(0)
 
         f = function([x1], unbroadcast_op(x1), mode=self.mode)
@@ -576,7 +576,7 @@ class TestUnbroadcast:
         )
 
     def test_local_unbroadcast_lift(self):
-        x = tensor("float64", shape=(1, 1))
+        x = tensor(dtype="float64", shape=(1, 1))
         y = unbroadcast(at.exp(unbroadcast(x, 0)), 1)
 
         assert (
@@ -1693,8 +1693,8 @@ class TestLocalElemwiseAlloc:
         ],
     )
     def test_basic(self, expr, x_shape, y_shape):
-        x = at.tensor("int64", shape=(None,) * len(x_shape), name="x")
-        y = at.tensor("int64", shape=(None,) * len(y_shape), name="y")
+        x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
+        y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
         z = expr(x, y)
 
         z_opt = pytensor.function(
......
@@ -1125,7 +1125,7 @@ class TestFusion:
             "inplace",
         )
 
-        x = tensor("floatX", shape=(None, None, None), name="x")
+        x = tensor(dtype="floatX", shape=(None, None, None), name="x")
         out = exp(x).sum(axis=axis)
 
         out_fn = function([x], out, mode=mode)
@@ -1151,7 +1151,7 @@ class TestFusion:
         )
 
         # `Elemwise`s with more than one client shouldn't be rewritten
-        x = tensor("floatX", shape=(None, None, None), name="x")
+        x = tensor(dtype="floatX", shape=(None, None, None), name="x")
         exp_x = exp(x)
         out = exp_x.sum(axis=axis) + exp(x)
@@ -1176,8 +1176,8 @@ class TestFusion:
             "inplace",
         )
 
-        x = tensor("floatX", shape=(None, None, None), name="x")
-        y = tensor("floatX", shape=(None, None, None), name="y")
+        x = tensor(dtype="floatX", shape=(None, None, None), name="x")
+        y = tensor(dtype="floatX", shape=(None, None, None), name="y")
         out = (x + y).sum(axis=axis)
 
         out_fn = function([x, y], out, mode=mode)
......
@@ -959,11 +959,11 @@ class TestAlgebraicCanonizer:
     def test_mismatching_types(self):
         a = at.as_tensor([[0.0]], dtype=np.float64)
-        b = tensor("float64", shape=(None,)).dimshuffle("x", 0)
+        b = tensor(dtype="float64", shape=(None,)).dimshuffle("x", 0)
         z = add(a, b)
 
         # Construct a node with the wrong output `Type`
         z = Apply(
-            z.owner.op, z.owner.inputs, [tensor("float64", shape=(None, None))]
+            z.owner.op, z.owner.inputs, [tensor(dtype="float64", shape=(None, None))]
         ).outputs[0]
 
         z_rewritten = rewrite_graph(
......
@@ -494,7 +494,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
 def test_local_Shape_i_ground():
-    x = tensor(np.float64, shape=(None, 2))
+    x = tensor(dtype=np.float64, shape=(None, 2))
     s = Shape_i(1)(x)
 
     fgraph = FunctionGraph(outputs=[s], clone=False)
@@ -92,7 +92,7 @@ z = create_pytensor_param(np.random.default_rng().integers(0, 4, size=(2, 2)))
 def test_local_replace_AdvancedSubtensor(indices, is_none):
     X_val = np.random.normal(size=(4, 4, 4))
-    X = tensor(np.float64, shape=(None, None, None), name="X")
+    X = tensor(dtype=np.float64, shape=(None, None, None), name="X")
     X.tag.test_value = X_val
 
     Y = X[indices]
@@ -932,7 +932,7 @@ class TestLocalSubtensorLift:
         assert (f1(xval) == xval[:2, :5]).all()
 
         # corner case 1: Unbroadcast changes dims which are dropped through subtensor
-        y = tensor("float64", shape=(1, 10, 1, 3), name="x")
+        y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x")
         yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
         assert y.broadcastable == (True, False, True, False)
         newy = Unbroadcast(0, 2)(y)
@@ -956,7 +956,7 @@ class TestLocalSubtensorLift:
         assert (f3(yval) == yval[:, 3, 0]).all()
 
         # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
-        z = tensor("float64", shape=(4, 10, 3, 1), name="x")
+        z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x")
         zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
         assert z.broadcastable == (False, False, False, True)
         newz = Unbroadcast(3)(z)
@@ -1911,7 +1911,7 @@ def test_local_subtensor_of_alloc():
 def test_local_subtensor_shape_constant():
-    x = tensor(np.float64, shape=(1, None)).shape[0]
+    x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
     assert isinstance(res, Constant)
     assert res.data == 1
@@ -1921,21 +1921,21 @@ def test_local_subtensor_shape_constant():
     assert isinstance(res, Constant)
     assert res.data == 1
 
-    x = _shape(tensor(np.float64, shape=(1, None)))[lscalar()]
+    x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar()]
     assert not local_subtensor_shape_constant.transform(None, x.owner)
 
-    x = _shape(tensor(np.float64, shape=(1, None)))[0:]
+    x = _shape(tensor(dtype=np.float64, shape=(1, None)))[0:]
     assert not local_subtensor_shape_constant.transform(None, x.owner)
 
-    x = _shape(tensor(np.float64, shape=(1, None)))[lscalar() :]
+    x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar() :]
     assert not local_subtensor_shape_constant.transform(None, x.owner)
 
-    x = _shape(tensor(np.float64, shape=(1, 1)))[1:]
+    x = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
     assert isinstance(res, Constant)
     assert np.array_equal(res.data, [1])
 
-    x = _shape(tensor(np.float64, shape=(None, 1, 1)))[1:]
+    x = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:]
     (res,) = local_subtensor_shape_constant.transform(None, x.owner)
     assert isinstance(res, Constant)
     assert np.array_equal(res.data, [1, 1])
......
@@ -1177,7 +1177,7 @@ def test_get_vector_length():
     # Test `Alloc`s
     assert 3 == get_vector_length(alloc(0, 3))
-    assert 5 == get_vector_length(tensor(np.float64, shape=(5,)))
+    assert 5 == get_vector_length(tensor(dtype=np.float64, shape=(5,)))
 
 
 class TestJoinAndSplit:
@@ -4263,10 +4263,10 @@ class TestTakeAlongAxis:
         indices = rng.integers(low=0, high=shape[axis or 0], size=indices_size)
 
         arr_in = at.tensor(
-            config.floatX, shape=tuple(1 if s == 1 else None for s in arr.shape)
+            dtype=config.floatX, shape=tuple(1 if s == 1 else None for s in arr.shape)
         )
         indices_in = at.tensor(
-            np.int64, shape=tuple(1 if s == 1 else None for s in indices.shape)
+            dtype=np.int64, shape=tuple(1 if s == 1 else None for s in indices.shape)
         )
 
         out = at.take_along_axis(arr_in, indices_in, axis)
@@ -4278,12 +4278,12 @@ class TestTakeAlongAxis:
     )
     def test_ndim_dtype_failures(self):
-        arr = at.tensor(config.floatX, shape=(None,) * 2)
-        indices = at.tensor(np.int64, shape=(None,) * 3)
+        arr = at.tensor(dtype=config.floatX, shape=(None,) * 2)
+        indices = at.tensor(dtype=np.int64, shape=(None,) * 3)
 
         with pytest.raises(ValueError):
             at.take_along_axis(arr, indices)
 
-        indices = at.tensor(np.float64, shape=(None,) * 2)
+        indices = at.tensor(dtype=np.float64, shape=(None,) * 2)
 
         with pytest.raises(IndexError):
             at.take_along_axis(arr, indices)
@@ -4310,7 +4310,7 @@ def test_oriented_stack_functions(func):
     with pytest.raises(ValueError):
         func()
 
-    a = at.tensor(np.float64, shape=(None, None, None))
+    a = at.tensor(dtype=np.float64, shape=(None, None, None))
 
     with pytest.raises(ValueError):
         func(a, a)
@@ -185,7 +185,7 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
         assert np.allclose(np.mean(block_diffs), 0)
 
     def test_static_shape(self):
-        x = tensor(np.float64, shape=(1, 2), name="x")
+        x = tensor(dtype=np.float64, shape=(1, 2), name="x")
         y = x.dimshuffle([0, 1, "x"])
         assert y.type.shape == (1, 2, 1)
@@ -852,8 +852,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
             z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
 
     def test_shape_types(self):
-        x = tensor(np.float64, (None, 1))
-        y = tensor(np.float64, (50, 10))
+        x = tensor(dtype=np.float64, shape=(None, 1))
+        y = tensor(dtype=np.float64, shape=(50, 10))
         z = x * y
@@ -864,33 +864,33 @@ class TestElemwise(unittest_tools.InferShapeTester):
         assert all(isinstance(v.type, TensorType) for v in out_shape)
 
     def test_static_shape_unary(self):
-        x = tensor("float64", shape=(None, 0, 1, 5))
+        x = tensor(dtype="float64", shape=(None, 0, 1, 5))
         assert exp(x).type.shape == (None, 0, 1, 5)
 
     def test_static_shape_binary(self):
-        x = tensor("float64", shape=(None, 5))
-        y = tensor("float64", shape=(None, 5))
+        x = tensor(dtype="float64", shape=(None, 5))
+        y = tensor(dtype="float64", shape=(None, 5))
         assert (x + y).type.shape == (None, 5)
 
-        x = tensor("float64", shape=(None, 5))
-        y = tensor("float64", shape=(10, 5))
+        x = tensor(dtype="float64", shape=(None, 5))
+        y = tensor(dtype="float64", shape=(10, 5))
         assert (x + y).type.shape == (10, 5)
 
-        x = tensor("float64", shape=(1, 5))
-        y = tensor("float64", shape=(10, 5))
+        x = tensor(dtype="float64", shape=(1, 5))
+        y = tensor(dtype="float64", shape=(10, 5))
        assert (x + y).type.shape == (10, 5)
 
-        x = tensor("float64", shape=(None, 1))
-        y = tensor("float64", shape=(1, 1))
+        x = tensor(dtype="float64", shape=(None, 1))
+        y = tensor(dtype="float64", shape=(1, 1))
         assert (x + y).type.shape == (None, 1)
 
-        x = tensor("float64", shape=(0, 0, 0))
-        y = tensor("float64", shape=(0, 1, None))
+        x = tensor(dtype="float64", shape=(0, 0, 0))
+        y = tensor(dtype="float64", shape=(0, 1, None))
         assert (x + y).type.shape == (0, 0, 0)
 
     def test_invalid_static_shape(self):
-        x = tensor("float64", shape=(2,))
-        y = tensor("float64", shape=(3,))
+        x = tensor(dtype="float64", shape=(2,))
+        y = tensor(dtype="float64", shape=(3,))
         with pytest.raises(
             ValueError,
             match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"),
......
@@ -1335,7 +1335,7 @@ class TestBroadcastTo(utt.InferShapeTester):
     def test_infer_shape(self):
         rng = np.random.default_rng(43)
 
-        a = tensor(config.floatX, shape=(None, 1, None))
+        a = tensor(dtype=config.floatX, shape=(None, 1, None))
         shape = list(a.shape)
         out = self.op(a, shape)
@@ -1346,7 +1346,7 @@ class TestBroadcastTo(utt.InferShapeTester):
             self.op_class,
         )
 
-        a = tensor(config.floatX, shape=(None, 1, None))
+        a = tensor(dtype=config.floatX, shape=(None, 1, None))
         shape = [iscalar() for i in range(4)]
         self._compile_and_check(
             [a] + shape,
......
@@ -650,7 +650,7 @@ class TestUnbroadcast:
 class TestUnbroadcastInferShape(utt.InferShapeTester):
     def test_basic(self):
         rng = np.random.default_rng(3453)
-        adtens4 = tensor("float64", shape=(1, 1, 1, None))
+        adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None))
         adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
         self._compile_and_check(
             [adtens4],
@@ -666,7 +666,7 @@ def test_shape_tuple():
     x = Variable(MyType2(), None, None)
     assert shape_tuple(x) == ()
 
-    x = tensor(np.float64, shape=(1, 2, None))
+    x = tensor(dtype=np.float64, shape=(1, 2, None))
     res = shape_tuple(x)
     assert isinstance(res, tuple)
     assert isinstance(res[0], ScalarConstant)
......
@@ -7,7 +7,20 @@ import pytest
 import pytensor.tensor as at
 from pytensor.configdefaults import config
 from pytensor.tensor.shape import SpecifyShape
-from pytensor.tensor.type import TensorType
+from pytensor.tensor.type import (
+    TensorType,
+    col,
+    matrix,
+    row,
+    scalar,
+    tensor,
+    tensor3,
+    tensor4,
+    tensor5,
+    tensor6,
+    tensor7,
+    vector,
+)
 
 
 @pytest.mark.parametrize(
@@ -326,3 +339,114 @@ def test_deprecated_kwargs():
     new_res = res.clone(broadcastable=(False, True))
     assert new_res.shape == (None, 1)
+
+
+def test_tensor_creator_helper():
+    res = tensor(shape=(5, None))
+    assert res.type == TensorType(config.floatX, shape=(5, None))
+    assert res.name is None
+
+    res = tensor(dtype="int64", shape=(5, None), name="custom")
+    assert res.type == TensorType(dtype="int64", shape=(5, None))
+    assert res.name == "custom"
+
+    # Test with positional name argument
+    res = tensor("custom", dtype="int64", shape=(5, None))
+    assert res.type == TensorType(dtype="int64", shape=(5, None))
+    assert res.name == "custom"
+
+    with pytest.warns(
+        DeprecationWarning, match="The `broadcastable` keyword is deprecated"
+    ):
+        res = tensor(dtype="int64", broadcastable=(True, False), name="custom")
+    assert res.type == TensorType("int64", shape=(1, None))
+    assert res.name == "custom"
+
+
+@pytest.mark.parametrize("dtype", ("floatX", "float64", bool, np.int64))
+def test_tensor_creator_dtype_catch(dtype):
+    with pytest.raises(
+        ValueError,
+        match="This name looks like a dtype, which you should pass as a keyword argument only",
+    ):
+        tensor(dtype, shape=(None,))
+
+    # This should work
+    assert tensor(dtype=dtype, shape=(None,))
+
+
+def test_tensor_creator_ignores_rare_dtype_name():
+    # This could be a dtype, but we assume it's a name
+    assert tensor("a", shape=(None,)).type.dtype == config.floatX
+
+
+def test_scalar_creator_helper():
+    default = scalar()
+    assert default.type.dtype == config.floatX
+    assert default.type.ndim == 0
+    assert default.type.shape == ()
+    assert default.name is None
+
+    custom = scalar(name="custom", dtype="int64")
+    assert custom.dtype == "int64"
+    assert custom.type.ndim == 0
+    assert custom.type.shape == ()
+
+
+@pytest.mark.parametrize(
+    "helper, ndims",
+    [
+        (vector, 1),
+        (matrix, 2),
+        (row, 2),
+        (col, 2),
+        (tensor3, 3),
+        (tensor4, 4),
+        (tensor5, 5),
+        (tensor6, 6),
+        (tensor7, 7),
+    ],
+)
+def test_tensor_creator_helpers(helper, ndims):
+    if helper is row:
+        default_shape = (1, None)
+        custom_shape = (1, 5)
+    elif helper is col:
+        default_shape = (None, 1)
+        custom_shape = (5, 1)
+    else:
+        default_shape = (None,) * ndims
+        custom_shape = tuple(range(ndims))
+
+    default = helper()
+    assert default.type.dtype == config.floatX
+    assert default.type.ndim == ndims
+    assert default.type.shape == default_shape
+    assert default.name is None
+    assert helper(shape=default_shape).type == default.type
+
+    custom = helper(name="custom", dtype="int64", shape=custom_shape)
+    assert custom.type.dtype == "int64"
+    assert custom.type.ndim == ndims
+    assert custom.type.shape == custom_shape
+    assert custom.name == "custom"
+
+    with pytest.raises(TypeError, match="Shape must be a tuple"):
+        helper(shape=list(default_shape))
+
+    with pytest.raises(ValueError, match=f"Shape must be a tuple of length {ndims}"):
+        helper(shape=(None,) + default_shape)
+
+    with pytest.raises(TypeError, match="Shape entries must be None or integer"):
+        helper(shape=(1.0,) * ndims)
+
+
+@pytest.mark.parametrize("helper", (row, col))
+def test_row_matrix_creator_helpers(helper):
+    if helper is row:
+        match = "The first dimension of a `row` must have shape 1, got 2"
+    else:
+        match = "The second dimension of a `col` must have shape 1, got 5"
+    with pytest.raises(ValueError, match=match):
+        helper(shape=(2, 5))
@@ -303,7 +303,7 @@ class TestIfelse(utt.OptimizationTestMixin):
         rng = np.random.default_rng(utt.fetch_seed())
 
         data = rng.random(5).astype(self.dtype)
         x = self.shared(data)
-        y = col("y", self.dtype)
+        y = col("y", dtype=self.dtype)
         cond = iscalar("cond")
 
         with pytest.raises(TypeError):
@@ -316,7 +316,7 @@ class TestIfelse(utt.OptimizationTestMixin):
         data = rng.random(5).astype(self.dtype)
         x = self.shared(data)
         # print x.broadcastable
-        y = row("y", self.dtype)
+        y = row("y", dtype=self.dtype)
         # print y.broadcastable
         cond = iscalar("cond")
......