Commit f64d243d authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor tests.tensor.test_basic_opt tests

* Convert tests to pytest `parametrize` tests
* Use `pytest.raises`
* Remove unused timing code and methods
Parent c86a717b
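Below is a minimal sketch, not taken from this diff, of the two pytest patterns the commit applies: a hand-written loop over cases becomes a `@pytest.mark.parametrize` case list, and a manual "should have raised" check becomes a `pytest.raises` block. The test names and the NumPy-based bodies are hypothetical stand-ins for the Aesara graphs used in the real tests.

```python
import numpy as np
import pytest


# Before: `for dtype in ["int32", "int64"]: ...` inside one test function.
# After: each dtype is collected and reported as its own test case.
@pytest.mark.parametrize("dtype", ["int32", "int64"])
def test_zeros_dtype(dtype):
    x = np.zeros((2, 2), dtype=dtype)
    assert x.dtype == np.dtype(dtype)


# Before: try/except plus a manually raised "should have failed" exception.
# After: `pytest.raises` fails the test if no exception is raised.
def test_bad_literal_raises():
    with pytest.raises(ValueError):
        int("not a number")
```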
import copy
import time
import numpy as np
import pytest
@@ -354,45 +353,36 @@ class TestFusion:
_shared = staticmethod(shared)
topo_exclude = ()
def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None):
"""
param shared_fn: if None, will use function
verify that the elemwise fusion work
Test with and without DimShuffle
"""
# TODO: disable the canonizer?
def my_init(shp, dtype="float64", num=0):
ret = np.zeros(shp, dtype=dtype) + num
return ret
fw, fx, fy, fz = [
tensor(dtype="float32", broadcastable=[False] * len(shp), name=n)
for n in "wxyz"
]
dw, dx, dy, dz = [
tensor(dtype="float64", broadcastable=[False] * len(shp), name=n)
for n in "wxyz"
]
ix, iy, iz = [
tensor(dtype="int32", broadcastable=[False] * len(shp), name=n)
for n in "xyz"
]
fv = fvector("v")
fs = fscalar("s")
fwv = my_init(shp, "float32", 1)
fxv = my_init(shp, "float32", 2)
fyv = my_init(shp, "float32", 3)
fzv = my_init(shp, "float32", 4)
fvv = _asarray(np.random.random((shp[0])), dtype="float32")
fsv = np.asarray(np.random.random(), dtype="float32")
dwv = my_init(shp, "float64", 5)
ixv = _asarray(my_init(shp, num=60), dtype="int32")
iyv = _asarray(my_init(shp, num=70), dtype="int32")
izv = _asarray(my_init(shp, num=70), dtype="int32")
fwx = fw + fx
ftanx = tan(fx)
cases = [
def my_init(dtype="float64", num=0):
return np.zeros((5, 5), dtype=dtype) + num
fw, fx, fy, fz = [
tensor(dtype="float32", broadcastable=[False] * 2, name=n) for n in "wxyz"
]
dw, dx, dy, dz = [
tensor(dtype="float64", broadcastable=[False] * 2, name=n) for n in "wxyz"
]
ix, iy, iz = [
tensor(dtype="int32", broadcastable=[False] * 2, name=n) for n in "xyz"
]
fv = fvector("v")
fs = fscalar("s")
fwv = my_init("float32", 1)
fxv = my_init("float32", 2)
fyv = my_init("float32", 3)
fzv = my_init("float32", 4)
fvv = _asarray(np.random.random(5), dtype="float32")
fsv = np.asarray(np.random.random(), dtype="float32")
dwv = my_init("float64", 5)
ixv = _asarray(my_init(num=60), dtype="int32")
iyv = _asarray(my_init(num=70), dtype="int32")
izv = _asarray(my_init(num=70), dtype="int32")
fwx = fw + fx
ftanx = tan(fx)
@pytest.mark.parametrize(
"case",
[
(
fx + fy + fz,
(fx, fy, fz),
@@ -991,68 +981,52 @@ class TestFusion:
fxv * np.sin(fsv),
"float32",
),
]
if slice:
cases = cases[slice]
times = np.zeros(len(cases))
fail1 = []
fail2 = []
fail3 = []
fail4 = []
for (
id,
[g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype],
) in enumerate(cases):
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
if shared_fn is None:
f = function(list(sym_inputs), g, mode=mode)
for x in range(nb_repeat):
out = f(*val_inputs)
t1 = time.time()
else:
out = shared_fn(np.zeros(shp, dtype=out_dtype), "out")
assert out.dtype == g.dtype
f = function(sym_inputs, [], updates=[(out, g)], mode=mode)
t0 = time.time()
for x in range(nb_repeat):
f(*val_inputs)
t1 = time.time()
out = out.get_value()
times[id] = t1 - t0
atol = 1e-8
if out_dtype == "float32":
atol = 1e-6
if not np.allclose(out, answer * nb_repeat, atol=atol):
fail1.append(id)
topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
if assert_len_topo:
if not len(topo_) == nb_elemwise:
fail3.append((id, topo_, nb_elemwise))
if nb_elemwise == 1:
# If no variable appears multiple times in the
# inputs of g, check that the number of inputs
# to the Composite Elemwise is correct
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = np.sum(
[not isinstance(x, Constant) for x in topo_[0].inputs]
)
assert expected_len_sym_inputs == len(sym_inputs)
if not out_dtype == out.dtype:
fail4.append((id, out_dtype, out.dtype))
assert len(fail1 + fail2 + fail3 + fail4) == 0
return times
def test_elemwise_fusion(self):
shp = (5, 5)
self.do(self.mode, self._shared, shp)
],
)
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
"""Verify that `Elemwise` fusion works."""
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
if self._shared is None:
f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat):
out = f(*val_inputs)
else:
out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
assert out.dtype == g.dtype
f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
for x in range(nb_repeat):
f(*val_inputs)
out = out.get_value()
atol = 1e-8
if out_dtype == "float32":
atol = 1e-6
assert np.allclose(out, answer * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
if assert_len_topo:
assert len(topo_) == nb_elemwise
if nb_elemwise == 1:
# If no variable appears multiple times in the
# inputs of g, check that the number of inputs
# to the Composite Elemwise is correct
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = np.sum(
[not isinstance(x, Constant) for x in topo_[0].inputs]
)
assert expected_len_sym_inputs == len(sym_inputs)
assert out_dtype == out.dtype
def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
@@ -1144,78 +1118,6 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
)
def speed_fusion(self, s=None):
"""
param type s: a slice object
param s: a slice to apply to the case to execute. If None, exec all case.
"""
shp = (3000, 3000)
shp = (1000, 1000)
nb_repeat = 50
# linker=CLinker
# linker=OpWiseCLinker
mode1 = copy.copy(self.mode)
mode1._optimizer = mode1._optimizer.including("local_elemwise_fusion")
# TODO: clinker is much faster... but uses too much memory.
# Possible cause: there is no deletion of intermediate values when we don't keep the fct.
# More plausible cause: we keep a link to the output data?
# Follow up: CLinker does the same... second cause?
mode2 = copy.copy(self.mode)
mode2._optimizer = mode2._optimizer.excluding("local_elemwise_fusion")
print("test with linker", str(mode1.linker))
times1 = self.do(
mode1,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
times2 = self.do(
mode2,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
print("times1 with local_elemwise_fusion")
print(times1, times1.min(), times1.max(), times1.sum())
print("times2 without local_elemwise_fusion")
print(times2, times2.min(), times2.max(), times2.sum())
d = times2 / times1
print("times2/times1")
print(d)
print(
"min",
d.min(),
"argmin",
d.argmin(),
"max",
d.max(),
"mean",
d.mean(),
"std",
d.std(),
)
def speed_log_exp(self):
s = slice(31, 36)
print(
"time",
self.do(
self.mode,
self._shared,
shp=(1000, 1000),
assert_len_topo=False,
slice=s,
nb_repeat=100,
),
)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
@@ -2342,78 +2244,87 @@ class TestLocalUselessSwitch:
def setup_method(self):
self.mode = mode_opt.excluding("constant_folding")
@pytest.mark.parametrize(
"dtype1",
["int32", "int64"],
)
@pytest.mark.parametrize(
"dtype2",
["int32", "int64"],
)
@pytest.mark.parametrize(
"cond",
[0, 1, np.array([True])],
)
def test_const(self, cond):
for dtype1 in ["int32", "int64"]:
for dtype2 in ["int32", "int64"]:
x = matrix("x", dtype=dtype1)
y = matrix("y", dtype=dtype2)
z = aet.switch(cond, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.basic.Switch)
)
]
)
== 0
def test_const(self, dtype1, dtype2, cond):
x = matrix("x", dtype=dtype1)
y = matrix("y", dtype=dtype2)
z = aet.switch(cond, x, y)
f = function([x, y], z, mode=self.mode)
assert not any(
[
node.op
for node in f.maker.fgraph.toposort()
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.basic.Switch)
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(cond, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
def test_left_is_right(self):
for dtype1 in ["int32", "int64"]:
x = matrix("x", dtype=dtype1)
varc = matrix("varc", dtype=dtype1)
z1 = aet.switch(1, x, x)
z0 = aet.switch(0, x, x)
z2 = aet.switch(varc, x, x)
f1 = function([x], z1, mode=self.mode)
f0 = function([x], z0, mode=self.mode)
f2 = function([x, varc], z2, mode=self.mode)
topo = f1.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f0.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f2.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert np.array_equal(f1(vx), vx)
assert np.array_equal(f0(vx), vx)
assert np.array_equal(f2(vx, vc), vx)
def test_shape_le_0(self):
for dtype1 in ["float32", "float64"]:
x = matrix("x", dtype=dtype1)
z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0])
f0 = function([x], z0, mode=self.mode)
assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i)
z1 = aet.switch(le(x.shape[1], 0), 0, x.shape[1])
f1 = function([x], z1, mode=self.mode)
assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i)
vx = np.random.standard_normal((0, 5)).astype(dtype1)
assert f0(vx) == 0
assert f1(vx) == 5
]
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(cond, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
@pytest.mark.parametrize(
"dtype1",
["int32", "int64"],
)
def test_left_is_right(self, dtype1):
x = matrix("x", dtype=dtype1)
varc = matrix("varc", dtype=dtype1)
z1 = aet.switch(1, x, x)
z0 = aet.switch(0, x, x)
z2 = aet.switch(varc, x, x)
f1 = function([x], z1, mode=self.mode)
f0 = function([x], z0, mode=self.mode)
f2 = function([x, varc], z2, mode=self.mode)
topo = f1.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f0.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f2.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert np.array_equal(f1(vx), vx)
assert np.array_equal(f0(vx), vx)
assert np.array_equal(f2(vx, vc), vx)
@pytest.mark.parametrize(
"dtype1",
["float32", "float64"],
)
def test_shape_le_0(self, dtype1):
x = matrix("x", dtype=dtype1)
z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0])
f0 = function([x], z0, mode=self.mode)
assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i)
z1 = aet.switch(le(x.shape[1], 0), 0, x.shape[1])
f1 = function([x], z1, mode=self.mode)
assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i)
vx = np.random.standard_normal((0, 5)).astype(dtype1)
assert f0(vx) == 0
assert f1(vx) == 5
def test_broadcasting_1(self):
# test switch(cst, matrix, row)
@@ -2489,13 +2400,9 @@ class TestLocalUselessSwitch:
class TestLocalMergeSwitchSameCond:
def test_elemwise(self):
# float Ops
mats = matrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
for op in (
@pytest.mark.parametrize(
"op",
[
add,
sub,
mul,
@@ -2511,30 +2418,53 @@ class TestLocalMergeSwitchSameCond:
eq,
neq,
aet_pow,
):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
# integer Ops
mats = imatrices("cabxy")
],
)
def test_elemwise_float_ops(self, op):
# float Ops
mats = matrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
for op in (
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
@pytest.mark.parametrize(
"op",
[
bitwise_and,
bitwise_or,
bitwise_xor,
):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
],
)
def test_elemwise_int_ops(self, op):
# integer Ops
mats = imatrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
@pytest.mark.parametrize("op", [add, mul])
def test_elemwise_multi_inputs(self, op):
# add/mul with more than two inputs
mats = imatrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
u, v = matrices("uv")
s3 = aet.switch(c, u, v)
for op in (add, mul):
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1
class TestLocalOptAlloc:
"""
TODO FIXME: These tests are incomplete; they need to `assert` something.
"""
dtype = "float32"
def test_sum_upcast(self):
@@ -2571,20 +2501,17 @@ class TestLocalOptAllocF16(TestLocalOptAlloc):
class TestMakeVector(utt.InferShapeTester):
b = bscalar()
i = iscalar()
d = dscalar()
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method()
def test_make_vector(self):
b = bscalar()
i = iscalar()
d = dscalar()
# TODO: draw random values instead. Not really important.
val = {b: 2, i: -3, d: 0.7}
# Should work
for (dtype, inputs) in [
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, b)),
("int32", (i, b)),
("int32", (b, i)),
@@ -2593,56 +2520,63 @@ class TestMakeVector(utt.InferShapeTester):
("float64", (d, i)),
("float64", ()),
("int64", ()),
]:
mv = MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype
f = function([b, i, d], mv, on_unused_input="ignore")
f(val[b], val[i], val[d])
s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")
g = function([b, i, d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d])
if dtype in int_dtypes:
# The gradient should be 0
utt.assert_allclose(g_val, 0)
else:
for var, grval in zip((b, i, d), g_val):
float_inputs = []
if var.dtype in int_dtypes:
pass
# Currently we don't do any checks on these variables
# verify_grad doesn't support integer inputs yet
# however, the gradient on them is *not* defined to
# be 0
elif var not in inputs:
assert grval == 0
else:
float_inputs.append(var)
# Build a function that takes float_inputs, uses fixed values for the
# other inputs, and returns the MakeVector. Use it for verify_grad.
if float_inputs:
def fun(*fl_inputs):
f_inputs = []
for var in f_inputs:
if var in fl_inputs:
# use symbolic variable
f_inputs.append(var)
else:
# use constant value
f_inputs.append(val[var])
return MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
# should fail
for (dtype, inputs) in [
],
)
def test_make_vector(self, dtype, inputs):
b, i, d = self.b, self.i, self.d
val = {b: 2, i: -3, d: 0.7}
mv = MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype
f = function([b, i, d], mv, on_unused_input="ignore")
f(val[b], val[i], val[d])
s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")
g = function([b, i, d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d])
if dtype in int_dtypes:
# The gradient should be 0
utt.assert_allclose(g_val, 0)
else:
for var, grval in zip((b, i, d), g_val):
float_inputs = []
if var.dtype in int_dtypes:
pass
# Currently we don't do any checks on these variables
# verify_grad doesn't support integer inputs yet
# however, the gradient on them is *not* defined to
# be 0
elif var not in inputs:
assert grval == 0
else:
float_inputs.append(var)
# Build a function that takes float_inputs, uses fixed values for the
# other inputs, and returns the MakeVector. Use it for verify_grad.
if float_inputs:
def fun(*fl_inputs):
f_inputs = []
for var in f_inputs:
if var in fl_inputs:
# use symbolic variable
f_inputs.append(var)
else:
# use constant value
f_inputs.append(val[var])
return MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, i)),
("int8", (i, b)),
("int8", (b, d)),
@@ -2650,12 +2584,11 @@ class TestMakeVector(utt.InferShapeTester):
("int32", (d, i)),
("int32", (i, d)),
("float32", (i, d)),
]:
try:
MakeVector(dtype=dtype)(*inputs)
raise Exception("Aesara should have raised an error")
except AssertionError:
pass
],
)
def test_make_vector_fail(self, dtype, inputs):
with pytest.raises(AssertionError):
MakeVector(dtype=dtype)(*inputs)
def test_infer_shape(self):
adscal = dscalar()
@@ -2824,8 +2757,9 @@ def test_local_join_make_vector():
assert check_stack_trace(f, ops_to_check="all")
def test_local_tensor_scalar_tensor():
dtypes = [
@pytest.mark.parametrize(
"dtype",
[
"int8",
"int16",
"int32",
@@ -2838,25 +2772,24 @@ def test_local_tensor_scalar_tensor():
"float64",
"complex64",
"complex128",
]
for dtype in dtypes:
t_type = TensorType(dtype=dtype, broadcastable=())
t = t_type()
s = aet.scalar_from_tensor(t)
t2 = aet.tensor_from_scalar(s)
],
)
def test_local_tensor_scalar_tensor(dtype):
t_type = TensorType(dtype=dtype, broadcastable=())
t = t_type()
s = aet.scalar_from_tensor(t)
t2 = aet.tensor_from_scalar(s)
f = function([t], t2, mode=mode_opt)
e = f.maker.fgraph.toposort()
cast_nodes = [
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))
]
assert len(cast_nodes) == 0
f(0)
f = function([t], t2, mode=mode_opt)
e = f.maker.fgraph.toposort()
assert not any(
[n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
)
def test_local_scalar_tensor_scalar():
dtypes = [
@pytest.mark.parametrize(
"dtype",
[
"int8",
"int16",
"int32",
@@ -2869,21 +2802,19 @@ def test_local_scalar_tensor_scalar():
"float64",
"complex64",
"complex128",
]
for dtype in dtypes:
s_type = aes.Scalar(dtype=dtype)
s = s_type()
t = aet.tensor_from_scalar(s)
s2 = aet.scalar_from_tensor(t)
],
)
def test_local_scalar_tensor_scalar(dtype):
s_type = aes.Scalar(dtype=dtype)
s = s_type()
t = aet.tensor_from_scalar(s)
s2 = aet.scalar_from_tensor(t)
f = function([s], s2, mode=mode_opt)
e = f.maker.fgraph.toposort()
cast_nodes = [
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))
]
assert len(cast_nodes) == 0
f(0)
f = function([s], s2, mode=mode_opt)
e = f.maker.fgraph.toposort()
assert not any(
[n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
)
def test_local_useless_split():
@@ -2909,25 +2840,23 @@ def test_local_useless_split():
assert check_stack_trace(f_nonopt, ops_to_check="all")
def test_local_flatten_lift():
for i in range(1, 4):
x = tensor4()
out = aet.flatten(exp(x), i)
assert out.ndim == i
mode = get_default_mode()
mode = mode.including("local_reshape_lift")
f = function([x], out, mode=mode)
x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX)
out_np = f(x_np)
topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),)
assert shape_out_np == out_np.shape
@pytest.mark.parametrize("i", list(range(1, 4)))
def test_local_flatten_lift(i):
x = tensor4()
out = aet.flatten(exp(x), i)
assert out.ndim == i
mode = get_default_mode()
mode = mode.including("local_reshape_lift")
f = function([x], out, mode=mode)
x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX)
out_np = f(x_np)
topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),)
assert shape_out_np == out_np.shape
reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)]
assert len(reshape_nodes) == 1 and aet.is_flat(
reshape_nodes[0].outputs[0], ndim=i
)
assert isinstance(topo[-1].op, Elemwise)
reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)]
assert len(reshape_nodes) == 1 and aet.is_flat(reshape_nodes[0].outputs[0], ndim=i)
assert isinstance(topo[-1].op, Elemwise)
class TestReshape:
@@ -3118,15 +3047,16 @@ class TestShapeI(utt.InferShapeTester):
super().setup_method()
def test_perform(self):
rng = np.random.default_rng(utt.fetch_seed())
advec = vector()
advec_val = np.random.random((3)).astype(config.floatX)
advec_val = rng.random((3)).astype(config.floatX)
f = function([advec], Shape_i(0)(advec))
out = f(advec_val)
utt.assert_allclose(out, advec_val.shape[0])
admat = matrix()
admat_val = np.random.random((4, 3)).astype(config.floatX)
admat_val = rng.random((4, 3)).astype(config.floatX)
for i in range(2):
f = function([admat], Shape_i(i)(admat))
out = f(admat_val)
......