Unverified 提交 5c63ee70 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Allow passing static shape to tensor creation helpers (#118)

* Allow passing static shape to tensor creation helpers * Also default dtype to "floatX" when using `tensor` * Make tensor API similar to that of other variable constructors * Name is now the only optional non-keyword argument for all constructors
上级 43d91d0a
...@@ -3451,7 +3451,12 @@ class StructuredDot(Op): ...@@ -3451,7 +3451,12 @@ class StructuredDot(Op):
return Apply( return Apply(
self, self,
[a, b], [a, b],
[tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))], [
tensor(
dtype=dtype_out,
shape=(None, 1 if b.type.shape[1] == 1 else None),
)
],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp): ...@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp):
def make_node(self, a_indices, a_indptr, b, g_ab): def make_node(self, a_indices, a_indptr, b, g_ab):
return Apply( return Apply(
self, [a_indices, a_indptr, b, g_ab], [tensor(g_ab.dtype, shape=(None,))] self,
[a_indices, a_indptr, b, g_ab],
[tensor(dtype=g_ab.dtype, shape=(None,))],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp): ...@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp):
def make_node(self, a_indices, a_indptr, b, g_ab): def make_node(self, a_indices, a_indptr, b, g_ab):
return Apply( return Apply(
self, [a_indices, a_indptr, b, g_ab], [tensor(b.dtype, shape=(None,))] self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))]
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
......
...@@ -270,7 +270,11 @@ class StructuredDotCSC(COp): ...@@ -270,7 +270,11 @@ class StructuredDotCSC(COp):
r = Apply( r = Apply(
self, self,
[a_val, a_ind, a_ptr, a_nrows, b], [a_val, a_ind, a_ptr, a_nrows, b],
[tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))], [
tensor(
dtype=dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None)
)
],
) )
return r return r
...@@ -465,7 +469,12 @@ class StructuredDotCSR(COp): ...@@ -465,7 +469,12 @@ class StructuredDotCSR(COp):
r = Apply( r = Apply(
self, self,
[a_val, a_ind, a_ptr, b], [a_val, a_ind, a_ptr, b],
[tensor(self.dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))], [
tensor(
dtype=self.dtype_out,
shape=(None, 1 if b.type.shape[1] == 1 else None),
)
],
) )
return r return r
...@@ -705,7 +714,11 @@ class UsmmCscDense(_NoPythonCOp): ...@@ -705,7 +714,11 @@ class UsmmCscDense(_NoPythonCOp):
r = Apply( r = Apply(
self, self,
[alpha, x_val, x_ind, x_ptr, x_nrows, y, z], [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor(dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None))], [
tensor(
dtype=dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None)
)
],
) )
return r return r
...@@ -1142,7 +1155,9 @@ class MulSDCSC(_NoPythonCOp): ...@@ -1142,7 +1155,9 @@ class MulSDCSC(_NoPythonCOp):
""" """
assert b.type.ndim == 2 assert b.type.ndim == 2
return Apply( return Apply(
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))] self,
[a_data, a_indices, a_indptr, b],
[tensor(dtype=b.dtype, shape=(None,))],
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -1280,7 +1295,9 @@ class MulSDCSR(_NoPythonCOp): ...@@ -1280,7 +1295,9 @@ class MulSDCSR(_NoPythonCOp):
""" """
assert b.type.ndim == 2 assert b.type.ndim == 2
return Apply( return Apply(
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))] self,
[a_data, a_indices, a_indptr, b],
[tensor(dtype=b.dtype, shape=(None,))],
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -1470,7 +1487,9 @@ class MulSVCSR(_NoPythonCOp): ...@@ -1470,7 +1487,9 @@ class MulSVCSR(_NoPythonCOp):
""" """
assert b.type.ndim == 1 assert b.type.ndim == 1
return Apply( return Apply(
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))] self,
[a_data, a_indices, a_indptr, b],
[tensor(dtype=b.dtype, shape=(None,))],
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -1642,7 +1661,9 @@ class StructuredAddSVCSR(_NoPythonCOp): ...@@ -1642,7 +1661,9 @@ class StructuredAddSVCSR(_NoPythonCOp):
assert a_indptr.type.ndim == 1 assert a_indptr.type.ndim == 1
assert b.type.ndim == 1 assert b.type.ndim == 1
return Apply( return Apply(
self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))] self,
[a_data, a_indices, a_indptr, b],
[tensor(dtype=b.dtype, shape=(None,))],
) )
def c_code_cache_version(self): def c_code_cache_version(self):
......
...@@ -2882,7 +2882,7 @@ class ARange(Op): ...@@ -2882,7 +2882,7 @@ class ARange(Op):
assert step.ndim == 0 assert step.ndim == 0
inputs = [start, stop, step] inputs = [start, stop, step]
outputs = [tensor(self.dtype, shape=(None,))] outputs = [tensor(dtype=self.dtype, shape=(None,))]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
......
...@@ -1680,7 +1680,7 @@ class Dot22(GemmRelated): ...@@ -1680,7 +1680,7 @@ class Dot22(GemmRelated):
raise TypeError(y) raise TypeError(y)
if y.type.dtype != x.type.dtype: if y.type.dtype != x.type.dtype:
raise TypeError("dtype mismatch to Dot22") raise TypeError("dtype mismatch to Dot22")
outputs = [tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))] outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
return Apply(self, [x, y], outputs) return Apply(self, [x, y], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
...@@ -1985,7 +1985,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1985,7 +1985,7 @@ class Dot22Scalar(GemmRelated):
raise TypeError("Dot22Scalar requires float or complex args", a.dtype) raise TypeError("Dot22Scalar requires float or complex args", a.dtype)
sz = (x.type.shape[0], y.type.shape[1]) sz = (x.type.shape[0], y.type.shape[1])
outputs = [tensor(x.type.dtype, shape=sz)] outputs = [tensor(dtype=x.type.dtype, shape=sz)]
return Apply(self, [x, y, a], outputs) return Apply(self, [x, y, a], outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
...@@ -2221,7 +2221,7 @@ class BatchedDot(COp): ...@@ -2221,7 +2221,7 @@ class BatchedDot(COp):
+ inputs[1].type.shape[2:] + inputs[1].type.shape[2:]
) )
out_shape = tuple(1 if s == 1 else None for s in out_shape) out_shape = tuple(1 if s == 1 else None for s in out_shape)
return Apply(self, upcasted_inputs, [tensor(dtype, shape=out_shape)]) return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y = inp x, y = inp
......
...@@ -36,7 +36,7 @@ class LoadFromDisk(Op): ...@@ -36,7 +36,7 @@ class LoadFromDisk(Op):
def make_node(self, path): def make_node(self, path):
if isinstance(path, str): if isinstance(path, str):
path = Constant(Generic(), path) path = Constant(Generic(), path)
return Apply(self, [path], [tensor(self.dtype, shape=self.shape)]) return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
path = inp[0] path = inp[0]
...@@ -135,7 +135,7 @@ class MPIRecv(Op): ...@@ -135,7 +135,7 @@ class MPIRecv(Op):
[], [],
[ [
Variable(Generic(), None), Variable(Generic(), None),
tensor(self.dtype, shape=self.static_shape), tensor(dtype=self.dtype, shape=self.static_shape),
], ],
) )
...@@ -180,7 +180,7 @@ class MPIRecvWait(Op): ...@@ -180,7 +180,7 @@ class MPIRecvWait(Op):
return Apply( return Apply(
self, self,
[request, data], [request, data],
[tensor(data.dtype, shape=data.type.shape)], [tensor(dtype=data.dtype, shape=data.type.shape)],
) )
def perform(self, node, inp, out): def perform(self, node, inp, out):
......
...@@ -152,8 +152,8 @@ class MaxAndArgmax(COp): ...@@ -152,8 +152,8 @@ class MaxAndArgmax(COp):
if i not in all_axes if i not in all_axes
) )
outputs = [ outputs = [
tensor(x.type.dtype, shape=out_shape, name="max"), tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
tensor("int64", shape=out_shape, name="argmax"), tensor(dtype="int64", shape=out_shape, name="argmax"),
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -370,7 +370,7 @@ class Argmax(COp): ...@@ -370,7 +370,7 @@ class Argmax(COp):
# We keep the original broadcastable flags for dimensions on which # We keep the original broadcastable flags for dimensions on which
# we do not perform the argmax. # we do not perform the argmax.
out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes) out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
outputs = [tensor("int64", shape=out_shape, name="argmax")] outputs = [tensor(dtype="int64", shape=out_shape, name="argmax")]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
...@@ -1922,7 +1922,7 @@ class Dot(Op): ...@@ -1922,7 +1922,7 @@ class Dot(Op):
sz = sx[:-1] sz = sx[:-1]
i_dtypes = [input.type.dtype for input in inputs] i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)] outputs = [tensor(dtype=aes.upcast(*i_dtypes), shape=sz)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inp, out): def perform(self, node, inp, out):
......
...@@ -641,7 +641,7 @@ class Reshape(COp): ...@@ -641,7 +641,7 @@ class Reshape(COp):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
return Apply(self, [x, shp], [tensor(x.type.dtype, shape=out_shape)]) return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])
def perform(self, node, inp, out_, params): def perform(self, node, inp, out_, params):
x, shp = inp x, shp = inp
......
差异被折叠。
...@@ -556,13 +556,13 @@ def test_get_var_by_name(): ...@@ -556,13 +556,13 @@ def test_get_var_by_name():
def test_clone_new_inputs(): def test_clone_new_inputs():
"""Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes.""" """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
x = at.tensor(np.float64, shape=(None,)) x = at.tensor(dtype=np.float64, shape=(None,))
y = at.tensor(np.float64, shape=(1,)) y = at.tensor(dtype=np.float64, shape=(1,))
z = at.add(x, y) z = at.add(x, y)
assert z.type.shape == (None,) assert z.type.shape == (None,)
x_new = at.tensor(np.float64, shape=(1,)) x_new = at.tensor(dtype=np.float64, shape=(1,))
# The output nodes should be reconstructed, because the input types' static # The output nodes should be reconstructed, because the input types' static
# shape information increased in specificity # shape information increased in specificity
......
...@@ -146,7 +146,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -146,7 +146,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}` # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
( (
set_test_value( set_test_value(
at.tensor(config.floatX, shape=(None, 1, None), name="a"), at.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX), np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
), ),
("x", 2, "x", 0, "x"), ("x", 2, "x", 0, "x"),
...@@ -155,21 +155,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -155,21 +155,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
# `{'drop': [1], 'shuffle': [0], 'augment': []}` # `{'drop': [1], 'shuffle': [0], 'augment': []}`
( (
set_test_value( set_test_value(
at.tensor(config.floatX, shape=(None, 1), name="a"), at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
), ),
(0,), (0,),
), ),
( (
set_test_value( set_test_value(
at.tensor(config.floatX, shape=(None, 1), name="a"), at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
), ),
(0,), (0,),
), ),
( (
set_test_value( set_test_value(
at.tensor(config.floatX, shape=(1, 1, 1), name="a"), at.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
np.array([[[1.0]]], dtype=config.floatX), np.array([[[1.0]]], dtype=config.floatX),
), ),
(), (),
......
...@@ -270,7 +270,7 @@ rng = np.random.default_rng(42849) ...@@ -270,7 +270,7 @@ rng = np.random.default_rng(42849)
np.array([[1, 2], [3, 4]], dtype=np.float64), np.array([[1, 2], [3, 4]], dtype=np.float64),
), ),
set_test_value( set_test_value(
at.tensor("float64", shape=(1, None, None)), at.tensor(dtype="float64", shape=(1, None, None)),
np.eye(2)[None, ...], np.eye(2)[None, ...],
), ),
], ],
......
...@@ -582,7 +582,7 @@ def test_debugprint_mitmot(): ...@@ -582,7 +582,7 @@ def test_debugprint_mitmot():
def test_debugprint_compiled_fn(): def test_debugprint_compiled_fn():
M = at.tensor(np.float64, shape=(20000, 2, 2)) M = at.tensor(dtype=np.float64, shape=(20000, 2, 2))
one = at.as_tensor(1, dtype=np.int64) one = at.as_tensor(1, dtype=np.int64)
zero = at.as_tensor(0, dtype=np.int64) zero = at.as_tensor(0, dtype=np.int64)
......
...@@ -607,7 +607,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -607,7 +607,7 @@ def test_mvnormal_ShapeFeature():
assert M_at in graph_inputs([s2]) assert M_at in graph_inputs([s2])
# Test broadcasted shapes # Test broadcasted shapes
mean = tensor(config.floatX, shape=(1, None)) mean = tensor(dtype=config.floatX, shape=(1, None))
mean.tag.test_value = np.array([[0, 1, 2]], dtype=config.floatX) mean.tag.test_value = np.array([[0, 1, 2]], dtype=config.floatX)
test_covar = np.diag(np.array([1, 10, 100], dtype=config.floatX)) test_covar = np.diag(np.array([1, 10, 100], dtype=config.floatX))
......
...@@ -125,9 +125,9 @@ def test_RandomVariable_basics(): ...@@ -125,9 +125,9 @@ def test_RandomVariable_basics():
def test_RandomVariable_bcast(): def test_RandomVariable_bcast():
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True) rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
mu = tensor(config.floatX, shape=(1, None, None)) mu = tensor(dtype=config.floatX, shape=(1, None, None))
mu.tag.test_value = np.zeros((1, 2, 3)).astype(config.floatX) mu.tag.test_value = np.zeros((1, 2, 3)).astype(config.floatX)
sd = tensor(config.floatX, shape=(None, None)) sd = tensor(dtype=config.floatX, shape=(None, None))
sd.tag.test_value = np.ones((2, 3)).astype(config.floatX) sd.tag.test_value = np.ones((2, 3)).astype(config.floatX)
s1 = iscalar() s1 = iscalar()
...@@ -163,10 +163,10 @@ def test_RandomVariable_bcast_specify_shape(): ...@@ -163,10 +163,10 @@ def test_RandomVariable_bcast_specify_shape():
s3 = Assert("testing")(s3, eq(s1, 1)) s3 = Assert("testing")(s3, eq(s1, 1))
size = specify_shape(at.as_tensor([s1, s3, s2, s2, s1]), (5,)) size = specify_shape(at.as_tensor([s1, s3, s2, s2, s1]), (5,))
mu = tensor(config.floatX, shape=(None, None, 1)) mu = tensor(dtype=config.floatX, shape=(None, None, 1))
mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX) mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)
std = tensor(config.floatX, shape=(None, 1, 1)) std = tensor(dtype=config.floatX, shape=(None, 1, 1))
std.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX) std.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX)
res = rv(mu, std, size=size) res = rv(mu, std, size=size)
......
...@@ -69,7 +69,7 @@ def test_broadcast_params(): ...@@ -69,7 +69,7 @@ def test_broadcast_params():
# Try it in PyTensor # Try it in PyTensor
with config.change_flags(compute_test_value="raise"): with config.change_flags(compute_test_value="raise"):
mean = tensor(config.floatX, shape=(None, 1)) mean = tensor(dtype=config.floatX, shape=(None, 1))
mean.tag.test_value = np.array([[0], [10], [100]], dtype=config.floatX) mean.tag.test_value = np.array([[0], [10], [100]], dtype=config.floatX)
cov = matrix() cov = matrix()
cov.tag.test_value = np.diag(np.array([1e-6], dtype=config.floatX)) cov.tag.test_value = np.diag(np.array([1e-6], dtype=config.floatX))
......
...@@ -559,8 +559,8 @@ class TestUnbroadcast: ...@@ -559,8 +559,8 @@ class TestUnbroadcast:
self.mode = get_default_mode().including("canonicalize") self.mode = get_default_mode().including("canonicalize")
def test_local_useless_unbroadcast(self): def test_local_useless_unbroadcast(self):
x1 = tensor("float64", shape=(1, 2)) x1 = tensor(dtype="float64", shape=(1, 2))
x2 = tensor("float64", shape=(2, 1)) x2 = tensor(dtype="float64", shape=(2, 1))
unbroadcast_op = Unbroadcast(0) unbroadcast_op = Unbroadcast(0)
f = function([x1], unbroadcast_op(x1), mode=self.mode) f = function([x1], unbroadcast_op(x1), mode=self.mode)
...@@ -576,7 +576,7 @@ class TestUnbroadcast: ...@@ -576,7 +576,7 @@ class TestUnbroadcast:
) )
def test_local_unbroadcast_lift(self): def test_local_unbroadcast_lift(self):
x = tensor("float64", shape=(1, 1)) x = tensor(dtype="float64", shape=(1, 1))
y = unbroadcast(at.exp(unbroadcast(x, 0)), 1) y = unbroadcast(at.exp(unbroadcast(x, 0)), 1)
assert ( assert (
...@@ -1693,8 +1693,8 @@ class TestLocalElemwiseAlloc: ...@@ -1693,8 +1693,8 @@ class TestLocalElemwiseAlloc:
], ],
) )
def test_basic(self, expr, x_shape, y_shape): def test_basic(self, expr, x_shape, y_shape):
x = at.tensor("int64", shape=(None,) * len(x_shape), name="x") x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
y = at.tensor("int64", shape=(None,) * len(y_shape), name="y") y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
z = expr(x, y) z = expr(x, y)
z_opt = pytensor.function( z_opt = pytensor.function(
......
...@@ -1125,7 +1125,7 @@ class TestFusion: ...@@ -1125,7 +1125,7 @@ class TestFusion:
"inplace", "inplace",
) )
x = tensor("floatX", shape=(None, None, None), name="x") x = tensor(dtype="floatX", shape=(None, None, None), name="x")
out = exp(x).sum(axis=axis) out = exp(x).sum(axis=axis)
out_fn = function([x], out, mode=mode) out_fn = function([x], out, mode=mode)
...@@ -1151,7 +1151,7 @@ class TestFusion: ...@@ -1151,7 +1151,7 @@ class TestFusion:
) )
# `Elemwise`s with more than one client shouldn't be rewritten # `Elemwise`s with more than one client shouldn't be rewritten
x = tensor("floatX", shape=(None, None, None), name="x") x = tensor(dtype="floatX", shape=(None, None, None), name="x")
exp_x = exp(x) exp_x = exp(x)
out = exp_x.sum(axis=axis) + exp(x) out = exp_x.sum(axis=axis) + exp(x)
...@@ -1176,8 +1176,8 @@ class TestFusion: ...@@ -1176,8 +1176,8 @@ class TestFusion:
"inplace", "inplace",
) )
x = tensor("floatX", shape=(None, None, None), name="x") x = tensor(dtype="floatX", shape=(None, None, None), name="x")
y = tensor("floatX", shape=(None, None, None), name="y") y = tensor(dtype="floatX", shape=(None, None, None), name="y")
out = (x + y).sum(axis=axis) out = (x + y).sum(axis=axis)
out_fn = function([x, y], out, mode=mode) out_fn = function([x, y], out, mode=mode)
......
...@@ -959,11 +959,11 @@ class TestAlgebraicCanonizer: ...@@ -959,11 +959,11 @@ class TestAlgebraicCanonizer:
def test_mismatching_types(self): def test_mismatching_types(self):
a = at.as_tensor([[0.0]], dtype=np.float64) a = at.as_tensor([[0.0]], dtype=np.float64)
b = tensor("float64", shape=(None,)).dimshuffle("x", 0) b = tensor(dtype="float64", shape=(None,)).dimshuffle("x", 0)
z = add(a, b) z = add(a, b)
# Construct a node with the wrong output `Type` # Construct a node with the wrong output `Type`
z = Apply( z = Apply(
z.owner.op, z.owner.inputs, [tensor("float64", shape=(None, None))] z.owner.op, z.owner.inputs, [tensor(dtype="float64", shape=(None, None))]
).outputs[0] ).outputs[0]
z_rewritten = rewrite_graph( z_rewritten = rewrite_graph(
......
...@@ -494,7 +494,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1): ...@@ -494,7 +494,7 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
def test_local_Shape_i_ground(): def test_local_Shape_i_ground():
x = tensor(np.float64, shape=(None, 2)) x = tensor(dtype=np.float64, shape=(None, 2))
s = Shape_i(1)(x) s = Shape_i(1)(x)
fgraph = FunctionGraph(outputs=[s], clone=False) fgraph = FunctionGraph(outputs=[s], clone=False)
......
...@@ -92,7 +92,7 @@ z = create_pytensor_param(np.random.default_rng().integers(0, 4, size=(2, 2))) ...@@ -92,7 +92,7 @@ z = create_pytensor_param(np.random.default_rng().integers(0, 4, size=(2, 2)))
def test_local_replace_AdvancedSubtensor(indices, is_none): def test_local_replace_AdvancedSubtensor(indices, is_none):
X_val = np.random.normal(size=(4, 4, 4)) X_val = np.random.normal(size=(4, 4, 4))
X = tensor(np.float64, shape=(None, None, None), name="X") X = tensor(dtype=np.float64, shape=(None, None, None), name="X")
X.tag.test_value = X_val X.tag.test_value = X_val
Y = X[indices] Y = X[indices]
...@@ -932,7 +932,7 @@ class TestLocalSubtensorLift: ...@@ -932,7 +932,7 @@ class TestLocalSubtensorLift:
assert (f1(xval) == xval[:2, :5]).all() assert (f1(xval) == xval[:2, :5]).all()
# corner case 1: Unbroadcast changes dims which are dropped through subtensor # corner case 1: Unbroadcast changes dims which are dropped through subtensor
y = tensor("float64", shape=(1, 10, 1, 3), name="x") y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x")
yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
assert y.broadcastable == (True, False, True, False) assert y.broadcastable == (True, False, True, False)
newy = Unbroadcast(0, 2)(y) newy = Unbroadcast(0, 2)(y)
...@@ -956,7 +956,7 @@ class TestLocalSubtensorLift: ...@@ -956,7 +956,7 @@ class TestLocalSubtensorLift:
assert (f3(yval) == yval[:, 3, 0]).all() assert (f3(yval) == yval[:, 3, 0]).all()
# corner case 3: subtensor idx_list is shorter than Unbroadcast.axis # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
z = tensor("float64", shape=(4, 10, 3, 1), name="x") z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x")
zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
assert z.broadcastable == (False, False, False, True) assert z.broadcastable == (False, False, False, True)
newz = Unbroadcast(3)(z) newz = Unbroadcast(3)(z)
...@@ -1911,7 +1911,7 @@ def test_local_subtensor_of_alloc(): ...@@ -1911,7 +1911,7 @@ def test_local_subtensor_of_alloc():
def test_local_subtensor_shape_constant(): def test_local_subtensor_shape_constant():
x = tensor(np.float64, shape=(1, None)).shape[0] x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
(res,) = local_subtensor_shape_constant.transform(None, x.owner) (res,) = local_subtensor_shape_constant.transform(None, x.owner)
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert res.data == 1 assert res.data == 1
...@@ -1921,21 +1921,21 @@ def test_local_subtensor_shape_constant(): ...@@ -1921,21 +1921,21 @@ def test_local_subtensor_shape_constant():
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert res.data == 1 assert res.data == 1
x = _shape(tensor(np.float64, shape=(1, None)))[lscalar()] x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar()]
assert not local_subtensor_shape_constant.transform(None, x.owner) assert not local_subtensor_shape_constant.transform(None, x.owner)
x = _shape(tensor(np.float64, shape=(1, None)))[0:] x = _shape(tensor(dtype=np.float64, shape=(1, None)))[0:]
assert not local_subtensor_shape_constant.transform(None, x.owner) assert not local_subtensor_shape_constant.transform(None, x.owner)
x = _shape(tensor(np.float64, shape=(1, None)))[lscalar() :] x = _shape(tensor(dtype=np.float64, shape=(1, None)))[lscalar() :]
assert not local_subtensor_shape_constant.transform(None, x.owner) assert not local_subtensor_shape_constant.transform(None, x.owner)
x = _shape(tensor(np.float64, shape=(1, 1)))[1:] x = _shape(tensor(dtype=np.float64, shape=(1, 1)))[1:]
(res,) = local_subtensor_shape_constant.transform(None, x.owner) (res,) = local_subtensor_shape_constant.transform(None, x.owner)
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert np.array_equal(res.data, [1]) assert np.array_equal(res.data, [1])
x = _shape(tensor(np.float64, shape=(None, 1, 1)))[1:] x = _shape(tensor(dtype=np.float64, shape=(None, 1, 1)))[1:]
(res,) = local_subtensor_shape_constant.transform(None, x.owner) (res,) = local_subtensor_shape_constant.transform(None, x.owner)
assert isinstance(res, Constant) assert isinstance(res, Constant)
assert np.array_equal(res.data, [1, 1]) assert np.array_equal(res.data, [1, 1])
......
...@@ -1177,7 +1177,7 @@ def test_get_vector_length(): ...@@ -1177,7 +1177,7 @@ def test_get_vector_length():
# Test `Alloc`s # Test `Alloc`s
assert 3 == get_vector_length(alloc(0, 3)) assert 3 == get_vector_length(alloc(0, 3))
assert 5 == get_vector_length(tensor(np.float64, shape=(5,))) assert 5 == get_vector_length(tensor(dtype=np.float64, shape=(5,)))
class TestJoinAndSplit: class TestJoinAndSplit:
...@@ -4263,10 +4263,10 @@ class TestTakeAlongAxis: ...@@ -4263,10 +4263,10 @@ class TestTakeAlongAxis:
indices = rng.integers(low=0, high=shape[axis or 0], size=indices_size) indices = rng.integers(low=0, high=shape[axis or 0], size=indices_size)
arr_in = at.tensor( arr_in = at.tensor(
config.floatX, shape=tuple(1 if s == 1 else None for s in arr.shape) dtype=config.floatX, shape=tuple(1 if s == 1 else None for s in arr.shape)
) )
indices_in = at.tensor( indices_in = at.tensor(
np.int64, shape=tuple(1 if s == 1 else None for s in indices.shape) dtype=np.int64, shape=tuple(1 if s == 1 else None for s in indices.shape)
) )
out = at.take_along_axis(arr_in, indices_in, axis) out = at.take_along_axis(arr_in, indices_in, axis)
...@@ -4278,12 +4278,12 @@ class TestTakeAlongAxis: ...@@ -4278,12 +4278,12 @@ class TestTakeAlongAxis:
) )
def test_ndim_dtype_failures(self): def test_ndim_dtype_failures(self):
arr = at.tensor(config.floatX, shape=(None,) * 2) arr = at.tensor(dtype=config.floatX, shape=(None,) * 2)
indices = at.tensor(np.int64, shape=(None,) * 3) indices = at.tensor(dtype=np.int64, shape=(None,) * 3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
at.take_along_axis(arr, indices) at.take_along_axis(arr, indices)
indices = at.tensor(np.float64, shape=(None,) * 2) indices = at.tensor(dtype=np.float64, shape=(None,) * 2)
with pytest.raises(IndexError): with pytest.raises(IndexError):
at.take_along_axis(arr, indices) at.take_along_axis(arr, indices)
...@@ -4310,7 +4310,7 @@ def test_oriented_stack_functions(func): ...@@ -4310,7 +4310,7 @@ def test_oriented_stack_functions(func):
with pytest.raises(ValueError): with pytest.raises(ValueError):
func() func()
a = at.tensor(np.float64, shape=(None, None, None)) a = at.tensor(dtype=np.float64, shape=(None, None, None))
with pytest.raises(ValueError): with pytest.raises(ValueError):
func(a, a) func(a, a)
...@@ -185,7 +185,7 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -185,7 +185,7 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
assert np.allclose(np.mean(block_diffs), 0) assert np.allclose(np.mean(block_diffs), 0)
def test_static_shape(self): def test_static_shape(self):
x = tensor(np.float64, shape=(1, 2), name="x") x = tensor(dtype=np.float64, shape=(1, 2), name="x")
y = x.dimshuffle([0, 1, "x"]) y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1) assert y.type.shape == (1, 2, 1)
...@@ -852,8 +852,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -852,8 +852,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
def test_shape_types(self): def test_shape_types(self):
x = tensor(np.float64, (None, 1)) x = tensor(dtype=np.float64, shape=(None, 1))
y = tensor(np.float64, (50, 10)) y = tensor(dtype=np.float64, shape=(50, 10))
z = x * y z = x * y
...@@ -864,33 +864,33 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -864,33 +864,33 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert all(isinstance(v.type, TensorType) for v in out_shape) assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_static_shape_unary(self): def test_static_shape_unary(self):
x = tensor("float64", shape=(None, 0, 1, 5)) x = tensor(dtype="float64", shape=(None, 0, 1, 5))
assert exp(x).type.shape == (None, 0, 1, 5) assert exp(x).type.shape == (None, 0, 1, 5)
def test_static_shape_binary(self): def test_static_shape_binary(self):
x = tensor("float64", shape=(None, 5)) x = tensor(dtype="float64", shape=(None, 5))
y = tensor("float64", shape=(None, 5)) y = tensor(dtype="float64", shape=(None, 5))
assert (x + y).type.shape == (None, 5) assert (x + y).type.shape == (None, 5)
x = tensor("float64", shape=(None, 5)) x = tensor(dtype="float64", shape=(None, 5))
y = tensor("float64", shape=(10, 5)) y = tensor(dtype="float64", shape=(10, 5))
assert (x + y).type.shape == (10, 5) assert (x + y).type.shape == (10, 5)
x = tensor("float64", shape=(1, 5)) x = tensor(dtype="float64", shape=(1, 5))
y = tensor("float64", shape=(10, 5)) y = tensor(dtype="float64", shape=(10, 5))
assert (x + y).type.shape == (10, 5) assert (x + y).type.shape == (10, 5)
x = tensor("float64", shape=(None, 1)) x = tensor(dtype="float64", shape=(None, 1))
y = tensor("float64", shape=(1, 1)) y = tensor(dtype="float64", shape=(1, 1))
assert (x + y).type.shape == (None, 1) assert (x + y).type.shape == (None, 1)
x = tensor("float64", shape=(0, 0, 0)) x = tensor(dtype="float64", shape=(0, 0, 0))
y = tensor("float64", shape=(0, 1, None)) y = tensor(dtype="float64", shape=(0, 1, None))
assert (x + y).type.shape == (0, 0, 0) assert (x + y).type.shape == (0, 0, 0)
def test_invalid_static_shape(self): def test_invalid_static_shape(self):
x = tensor("float64", shape=(2,)) x = tensor(dtype="float64", shape=(2,))
y = tensor("float64", shape=(3,)) y = tensor(dtype="float64", shape=(3,))
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"), match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"),
......
...@@ -1335,7 +1335,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1335,7 +1335,7 @@ class TestBroadcastTo(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(43) rng = np.random.default_rng(43)
a = tensor(config.floatX, shape=(None, 1, None)) a = tensor(dtype=config.floatX, shape=(None, 1, None))
shape = list(a.shape) shape = list(a.shape)
out = self.op(a, shape) out = self.op(a, shape)
...@@ -1346,7 +1346,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1346,7 +1346,7 @@ class TestBroadcastTo(utt.InferShapeTester):
self.op_class, self.op_class,
) )
a = tensor(config.floatX, shape=(None, 1, None)) a = tensor(dtype=config.floatX, shape=(None, 1, None))
shape = [iscalar() for i in range(4)] shape = [iscalar() for i in range(4)]
self._compile_and_check( self._compile_and_check(
[a] + shape, [a] + shape,
......
...@@ -650,7 +650,7 @@ class TestUnbroadcast: ...@@ -650,7 +650,7 @@ class TestUnbroadcast:
class TestUnbroadcastInferShape(utt.InferShapeTester): class TestUnbroadcastInferShape(utt.InferShapeTester):
def test_basic(self): def test_basic(self):
rng = np.random.default_rng(3453) rng = np.random.default_rng(3453)
adtens4 = tensor("float64", shape=(1, 1, 1, None)) adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None))
adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX) adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check( self._compile_and_check(
[adtens4], [adtens4],
...@@ -666,7 +666,7 @@ def test_shape_tuple(): ...@@ -666,7 +666,7 @@ def test_shape_tuple():
x = Variable(MyType2(), None, None) x = Variable(MyType2(), None, None)
assert shape_tuple(x) == () assert shape_tuple(x) == ()
x = tensor(np.float64, shape=(1, 2, None)) x = tensor(dtype=np.float64, shape=(1, 2, None))
res = shape_tuple(x) res = shape_tuple(x)
assert isinstance(res, tuple) assert isinstance(res, tuple)
assert isinstance(res[0], ScalarConstant) assert isinstance(res[0], ScalarConstant)
......
...@@ -7,7 +7,20 @@ import pytest ...@@ -7,7 +7,20 @@ import pytest
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.shape import SpecifyShape from pytensor.tensor.shape import SpecifyShape
from pytensor.tensor.type import TensorType from pytensor.tensor.type import (
TensorType,
col,
matrix,
row,
scalar,
tensor,
tensor3,
tensor4,
tensor5,
tensor6,
tensor7,
vector,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -326,3 +339,114 @@ def test_deprecated_kwargs(): ...@@ -326,3 +339,114 @@ def test_deprecated_kwargs():
new_res = res.clone(broadcastable=(False, True)) new_res = res.clone(broadcastable=(False, True))
assert new_res.shape == (None, 1) assert new_res.shape == (None, 1)
def test_tensor_creator_helper():
    """`tensor` defaults dtype to floatX and accepts the name positionally or by keyword."""
    # No dtype given: falls back to config.floatX.
    var = tensor(shape=(5, None))
    assert var.type == TensorType(config.floatX, shape=(5, None))
    assert var.name is None

    # Explicit dtype and keyword name.
    var = tensor(dtype="int64", shape=(5, None), name="custom")
    assert var.type == TensorType(dtype="int64", shape=(5, None))
    assert var.name == "custom"

    # The name may also be passed as the first positional argument.
    var = tensor("custom", dtype="int64", shape=(5, None))
    assert var.type == TensorType(dtype="int64", shape=(5, None))
    assert var.name == "custom"

    # `broadcastable` is deprecated but still maps onto a static shape.
    with pytest.warns(
        DeprecationWarning, match="The `broadcastable` keyword is deprecated"
    ):
        var = tensor(dtype="int64", broadcastable=(True, False), name="custom")
    assert var.type == TensorType("int64", shape=(1, None))
    assert var.name == "custom"
@pytest.mark.parametrize("dtype", ("floatX", "float64", bool, np.int64))
def test_tensor_creator_dtype_catch(dtype):
    """A dtype-looking positional name is rejected; the keyword form works."""
    with pytest.raises(
        ValueError,
        match="This name looks like a dtype, which you should pass as a keyword argument only",
    ):
        tensor(dtype, shape=(None,))

    # Passing the same value through the `dtype` keyword is accepted.
    assert tensor(dtype=dtype, shape=(None,))
def test_tensor_creator_ignores_rare_dtype_name():
    """A positional string that only vaguely resembles a dtype is treated as a name."""
    var = tensor("a", shape=(None,))
    assert var.type.dtype == config.floatX
def test_scalar_creator_helper():
    """`scalar` builds a 0-d variable, defaulting to floatX and no name."""
    anon = scalar()
    assert anon.type.dtype == config.floatX
    assert anon.type.ndim == 0
    assert anon.type.shape == ()
    assert anon.name is None

    named = scalar(name="custom", dtype="int64")
    assert named.dtype == "int64"
    assert named.type.ndim == 0
    assert named.type.shape == ()
@pytest.mark.parametrize(
    "helper, ndims",
    [
        (vector, 1),
        (matrix, 2),
        (row, 2),
        (col, 2),
        (tensor3, 3),
        (tensor4, 4),
        (tensor5, 5),
        (tensor6, 6),
        (tensor7, 7),
    ],
)
def test_tensor_creator_helpers(helper, ndims):
    """Each rank-specific creator honors dtype/shape/name and validates shape input."""
    # `row`/`col` pin one dimension to 1; all other helpers default to fully
    # unknown static shapes.
    if helper is row:
        default_shape, custom_shape = (1, None), (1, 5)
    elif helper is col:
        default_shape, custom_shape = (None, 1), (5, 1)
    else:
        default_shape = (None,) * ndims
        custom_shape = tuple(range(ndims))

    anon = helper()
    assert anon.type.dtype == config.floatX
    assert anon.type.ndim == ndims
    assert anon.type.shape == default_shape
    assert anon.name is None
    # Spelling out the default shape explicitly produces the same type.
    assert helper(shape=default_shape).type == anon.type

    named = helper(name="custom", dtype="int64", shape=custom_shape)
    assert named.type.dtype == "int64"
    assert named.type.ndim == ndims
    assert named.type.shape == custom_shape
    assert named.name == "custom"

    # Shape validation: must be a tuple, of the right length, with None/int entries.
    with pytest.raises(TypeError, match="Shape must be a tuple"):
        helper(shape=list(default_shape))
    with pytest.raises(ValueError, match=f"Shape must be a tuple of length {ndims}"):
        helper(shape=(None,) + default_shape)
    with pytest.raises(TypeError, match="Shape entries must be None or integer"):
        helper(shape=(1.0,) * ndims)
@pytest.mark.parametrize("helper", (row, col))
def test_row_matrix_creator_helpers(helper):
    """`row`/`col` reject shapes whose broadcast dimension is not 1."""
    expected = (
        "The first dimension of a `row` must have shape 1, got 2"
        if helper is row
        else "The second dimension of a `col` must have shape 1, got 5"
    )
    with pytest.raises(ValueError, match=expected):
        helper(shape=(2, 5))
...@@ -303,7 +303,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -303,7 +303,7 @@ class TestIfelse(utt.OptimizationTestMixin):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
data = rng.random(5).astype(self.dtype) data = rng.random(5).astype(self.dtype)
x = self.shared(data) x = self.shared(data)
y = col("y", self.dtype) y = col("y", dtype=self.dtype)
cond = iscalar("cond") cond = iscalar("cond")
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -316,7 +316,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -316,7 +316,7 @@ class TestIfelse(utt.OptimizationTestMixin):
data = rng.random(5).astype(self.dtype) data = rng.random(5).astype(self.dtype)
x = self.shared(data) x = self.shared(data)
# print x.broadcastable # print x.broadcastable
y = row("y", self.dtype) y = row("y", dtype=self.dtype)
# print y.broadcastable # print y.broadcastable
cond = iscalar("cond") cond = iscalar("cond")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论