提交 f64d243d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor tests.tensor.test_basic_opt tests

* Convert tests to pytest `parametrized` tests * Use `pytest.raises` * Remove unused timing code and methods
上级 c86a717b
import copy import copy
import time
import numpy as np import numpy as np
import pytest import pytest
...@@ -354,45 +353,36 @@ class TestFusion: ...@@ -354,45 +353,36 @@ class TestFusion:
_shared = staticmethod(shared) _shared = staticmethod(shared)
topo_exclude = () topo_exclude = ()
def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None): def my_init(dtype="float64", num=0):
""" return np.zeros((5, 5), dtype=dtype) + num
param shared_fn: if None, will use function
verify that the elemwise fusion work fw, fx, fy, fz = [
Test with and without DimShuffle tensor(dtype="float32", broadcastable=[False] * 2, name=n) for n in "wxyz"
""" ]
# TODO: disable the canonizer? dw, dx, dy, dz = [
def my_init(shp, dtype="float64", num=0): tensor(dtype="float64", broadcastable=[False] * 2, name=n) for n in "wxyz"
ret = np.zeros(shp, dtype=dtype) + num ]
return ret ix, iy, iz = [
tensor(dtype="int32", broadcastable=[False] * 2, name=n) for n in "xyz"
fw, fx, fy, fz = [ ]
tensor(dtype="float32", broadcastable=[False] * len(shp), name=n) fv = fvector("v")
for n in "wxyz" fs = fscalar("s")
] fwv = my_init("float32", 1)
dw, dx, dy, dz = [ fxv = my_init("float32", 2)
tensor(dtype="float64", broadcastable=[False] * len(shp), name=n) fyv = my_init("float32", 3)
for n in "wxyz" fzv = my_init("float32", 4)
] fvv = _asarray(np.random.random(5), dtype="float32")
ix, iy, iz = [ fsv = np.asarray(np.random.random(), dtype="float32")
tensor(dtype="int32", broadcastable=[False] * len(shp), name=n) dwv = my_init("float64", 5)
for n in "xyz" ixv = _asarray(my_init(num=60), dtype="int32")
] iyv = _asarray(my_init(num=70), dtype="int32")
fv = fvector("v") izv = _asarray(my_init(num=70), dtype="int32")
fs = fscalar("s") fwx = fw + fx
ftanx = tan(fx)
fwv = my_init(shp, "float32", 1)
fxv = my_init(shp, "float32", 2) @pytest.mark.parametrize(
fyv = my_init(shp, "float32", 3) "case",
fzv = my_init(shp, "float32", 4) [
fvv = _asarray(np.random.random((shp[0])), dtype="float32")
fsv = np.asarray(np.random.random(), dtype="float32")
dwv = my_init(shp, "float64", 5)
ixv = _asarray(my_init(shp, num=60), dtype="int32")
iyv = _asarray(my_init(shp, num=70), dtype="int32")
izv = _asarray(my_init(shp, num=70), dtype="int32")
fwx = fw + fx
ftanx = tan(fx)
cases = [
( (
fx + fy + fz, fx + fy + fz,
(fx, fy, fz), (fx, fy, fz),
...@@ -991,68 +981,52 @@ class TestFusion: ...@@ -991,68 +981,52 @@ class TestFusion:
fxv * np.sin(fsv), fxv * np.sin(fsv),
"float32", "float32",
), ),
] ],
if slice: )
cases = cases[slice] def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
times = np.zeros(len(cases)) """Verify that `Elemwise` fusion works."""
fail1 = []
fail2 = [] g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
fail3 = []
fail4 = [] if isinstance(out_dtype, dict):
for ( out_dtype = out_dtype[config.cast_policy]
id,
[g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype], if self._shared is None:
) in enumerate(cases): f = function(list(sym_inputs), g, mode=self.mode)
if isinstance(out_dtype, dict): for x in range(nb_repeat):
out_dtype = out_dtype[config.cast_policy] out = f(*val_inputs)
else:
if shared_fn is None: out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
f = function(list(sym_inputs), g, mode=mode) assert out.dtype == g.dtype
for x in range(nb_repeat): f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
out = f(*val_inputs) for x in range(nb_repeat):
t1 = time.time() f(*val_inputs)
else: out = out.get_value()
out = shared_fn(np.zeros(shp, dtype=out_dtype), "out")
assert out.dtype == g.dtype atol = 1e-8
f = function(sym_inputs, [], updates=[(out, g)], mode=mode) if out_dtype == "float32":
t0 = time.time() atol = 1e-6
for x in range(nb_repeat):
f(*val_inputs) assert np.allclose(out, answer * nb_repeat, atol=atol)
t1 = time.time()
out = out.get_value() topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
times[id] = t1 - t0 if assert_len_topo:
atol = 1e-8
if out_dtype == "float32": assert len(topo_) == nb_elemwise
atol = 1e-6
if not np.allclose(out, answer * nb_repeat, atol=atol): if nb_elemwise == 1:
fail1.append(id) # if no variable appears multiple times in the
topo = f.maker.fgraph.toposort() # input of g,
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] # check that the number of input to the Composite
if assert_len_topo: # Elemwise is ok
if not len(topo_) == nb_elemwise: if len(set(g.owner.inputs)) == len(g.owner.inputs):
fail3.append((id, topo_, nb_elemwise)) expected_len_sym_inputs = np.sum(
if nb_elemwise == 1: [not isinstance(x, Constant) for x in topo_[0].inputs]
# if no variable appears multiple times in the )
# input of g, assert expected_len_sym_inputs == len(sym_inputs)
# check that the number of input to the Composite
# Elemwise is ok assert out_dtype == out.dtype
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = np.sum(
[not isinstance(x, Constant) for x in topo_[0].inputs]
)
assert expected_len_sym_inputs == len(sym_inputs)
if not out_dtype == out.dtype:
fail4.append((id, out_dtype, out.dtype))
assert len(fail1 + fail2 + fail3 + fail4) == 0
return times
def test_elemwise_fusion(self):
shp = (5, 5)
self.do(self.mode, self._shared, shp)
def test_fusion_35_inputs(self): def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
...@@ -1144,78 +1118,6 @@ class TestFusion: ...@@ -1144,78 +1118,6 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
) )
def speed_fusion(self, s=None):
"""
param type s: a slice object
param s: a slice to apply to the case to execute. If None, exec all case.
"""
shp = (3000, 3000)
shp = (1000, 1000)
nb_repeat = 50
# linker=CLinker
# linker=OpWiseCLinker
mode1 = copy.copy(self.mode)
mode1._optimizer = mode1._optimizer.including("local_elemwise_fusion")
# TODO:clinker is much faster... but use to much memory
# Possible cause: as their is do deletion of intermediate value when we don't keep the fct.
# More plausible cause: we keep a link to the output data?
# Follow up. Clinker do the same... second cause?
mode2 = copy.copy(self.mode)
mode2._optimizer = mode2._optimizer.excluding("local_elemwise_fusion")
print("test with linker", str(mode1.linker))
times1 = self.do(
mode1,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
times2 = self.do(
mode2,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
print("times1 with local_elemwise_fusion")
print(times1, times1.min(), times1.max(), times1.sum())
print("times2 without local_elemwise_fusion")
print(times2, times2.min(), times2.max(), times2.sum())
d = times2 / times1
print("times2/times1")
print(d)
print(
"min",
d.min(),
"argmin",
d.argmin(),
"max",
d.max(),
"mean",
d.mean(),
"std",
d.std(),
)
def speed_log_exp(self):
s = slice(31, 36)
print(
"time",
self.do(
self.mode,
self._shared,
shp=(1000, 1000),
assert_len_topo=False,
slice=s,
nb_repeat=100,
),
)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler") @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self): def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
...@@ -2342,78 +2244,87 @@ class TestLocalUselessSwitch: ...@@ -2342,78 +2244,87 @@ class TestLocalUselessSwitch:
def setup_method(self): def setup_method(self):
self.mode = mode_opt.excluding("constant_folding") self.mode = mode_opt.excluding("constant_folding")
@pytest.mark.parametrize(
"dtype1",
["int32", "int64"],
)
@pytest.mark.parametrize(
"dtype2",
["int32", "int64"],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cond", "cond",
[0, 1, np.array([True])], [0, 1, np.array([True])],
) )
def test_const(self, cond): def test_const(self, dtype1, dtype2, cond):
for dtype1 in ["int32", "int64"]: x = matrix("x", dtype=dtype1)
for dtype2 in ["int32", "int64"]: y = matrix("y", dtype=dtype2)
x = matrix("x", dtype=dtype1) z = aet.switch(cond, x, y)
y = matrix("y", dtype=dtype2) f = function([x, y], z, mode=self.mode)
z = aet.switch(cond, x, y) assert not any(
f = function([x, y], z, mode=self.mode) [
assert ( node.op
len( for node in f.maker.fgraph.toposort()
[ if (
node.op isinstance(node.op, Elemwise)
for node in f.maker.fgraph.toposort() and isinstance(node.op.scalar_op, aes.basic.Switch)
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.basic.Switch)
)
]
)
== 0
) )
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) ]
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) )
np_res = np.where(cond, vx, vy) vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert np.array_equal(f(vx, vy), np_res) vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(cond, vx, vy)
def test_left_is_right(self): assert np.array_equal(f(vx, vy), np_res)
for dtype1 in ["int32", "int64"]:
x = matrix("x", dtype=dtype1) @pytest.mark.parametrize(
varc = matrix("varc", dtype=dtype1) "dtype1",
z1 = aet.switch(1, x, x) ["int32", "int64"],
z0 = aet.switch(0, x, x) )
z2 = aet.switch(varc, x, x) def test_left_is_right(self, dtype1):
f1 = function([x], z1, mode=self.mode) x = matrix("x", dtype=dtype1)
f0 = function([x], z0, mode=self.mode) varc = matrix("varc", dtype=dtype1)
f2 = function([x, varc], z2, mode=self.mode) z1 = aet.switch(1, x, x)
z0 = aet.switch(0, x, x)
topo = f1.maker.fgraph.toposort() z2 = aet.switch(varc, x, x)
assert len(topo) == 1 f1 = function([x], z1, mode=self.mode)
assert topo[0].op == deep_copy_op f0 = function([x], z0, mode=self.mode)
f2 = function([x, varc], z2, mode=self.mode)
topo = f0.maker.fgraph.toposort()
assert len(topo) == 1 topo = f1.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f2.maker.fgraph.toposort()
assert len(topo) == 1 topo = f0.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) topo = f2.maker.fgraph.toposort()
assert np.array_equal(f1(vx), vx) assert len(topo) == 1
assert np.array_equal(f0(vx), vx) assert topo[0].op == deep_copy_op
assert np.array_equal(f2(vx, vc), vx)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
def test_shape_le_0(self): vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
for dtype1 in ["float32", "float64"]: assert np.array_equal(f1(vx), vx)
x = matrix("x", dtype=dtype1) assert np.array_equal(f0(vx), vx)
z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0]) assert np.array_equal(f2(vx, vc), vx)
f0 = function([x], z0, mode=self.mode)
assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i) @pytest.mark.parametrize(
"dtype1",
z1 = aet.switch(le(x.shape[1], 0), 0, x.shape[1]) ["float32", "float64"],
f1 = function([x], z1, mode=self.mode) )
assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i) def test_shape_le_0(self, dtype1):
x = matrix("x", dtype=dtype1)
vx = np.random.standard_normal((0, 5)).astype(dtype1) z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0])
assert f0(vx) == 0 f0 = function([x], z0, mode=self.mode)
assert f1(vx) == 5 assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i)
z1 = aet.switch(le(x.shape[1], 0), 0, x.shape[1])
f1 = function([x], z1, mode=self.mode)
assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i)
vx = np.random.standard_normal((0, 5)).astype(dtype1)
assert f0(vx) == 0
assert f1(vx) == 5
def test_broadcasting_1(self): def test_broadcasting_1(self):
# test switch(cst, matrix, row) # test switch(cst, matrix, row)
...@@ -2489,13 +2400,9 @@ class TestLocalUselessSwitch: ...@@ -2489,13 +2400,9 @@ class TestLocalUselessSwitch:
class TestLocalMergeSwitchSameCond: class TestLocalMergeSwitchSameCond:
def test_elemwise(self): @pytest.mark.parametrize(
# float Ops "op",
mats = matrices("cabxy") [
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
for op in (
add, add,
sub, sub,
mul, mul,
...@@ -2511,30 +2418,53 @@ class TestLocalMergeSwitchSameCond: ...@@ -2511,30 +2418,53 @@ class TestLocalMergeSwitchSameCond:
eq, eq,
neq, neq,
aet_pow, aet_pow,
): ],
g = optimize(FunctionGraph(mats, [op(s1, s2)])) )
assert str(g).count("Switch") == 1 def test_elemwise_float_ops(self, op):
# integer Ops # float Ops
mats = imatrices("cabxy") mats = matrices("cabxy")
c, a, b, x, y = mats c, a, b, x, y = mats
s1 = aet.switch(c, a, b) s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y) s2 = aet.switch(c, x, y)
for op in (
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
@pytest.mark.parametrize(
"op",
[
bitwise_and, bitwise_and,
bitwise_or, bitwise_or,
bitwise_xor, bitwise_xor,
): ],
g = optimize(FunctionGraph(mats, [op(s1, s2)])) )
assert str(g).count("Switch") == 1 def test_elemwise_int_ops(self, op):
# integer Ops
mats = imatrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
@pytest.mark.parametrize("op", [add, mul])
def test_elemwise_multi_inputs(self, op):
# add/mul with more than two inputs # add/mul with more than two inputs
mats = imatrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
u, v = matrices("uv") u, v = matrices("uv")
s3 = aet.switch(c, u, v) s3 = aet.switch(c, u, v)
for op in (add, mul): g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) assert str(g).count("Switch") == 1
assert str(g).count("Switch") == 1
class TestLocalOptAlloc: class TestLocalOptAlloc:
"""
TODO FIXME: These tests are incomplete; they need to `assert` something.
"""
dtype = "float32" dtype = "float32"
def test_sum_upcast(self): def test_sum_upcast(self):
...@@ -2571,20 +2501,17 @@ class TestLocalOptAllocF16(TestLocalOptAlloc): ...@@ -2571,20 +2501,17 @@ class TestLocalOptAllocF16(TestLocalOptAlloc):
class TestMakeVector(utt.InferShapeTester): class TestMakeVector(utt.InferShapeTester):
b = bscalar()
i = iscalar()
d = dscalar()
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed()) self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method() super().setup_method()
def test_make_vector(self): @pytest.mark.parametrize(
b = bscalar() "dtype, inputs",
i = iscalar() [
d = dscalar()
# TODO: draw random values instead. Not really important.
val = {b: 2, i: -3, d: 0.7}
# Should work
for (dtype, inputs) in [
("int8", (b, b)), ("int8", (b, b)),
("int32", (i, b)), ("int32", (i, b)),
("int32", (b, i)), ("int32", (b, i)),
...@@ -2593,56 +2520,63 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -2593,56 +2520,63 @@ class TestMakeVector(utt.InferShapeTester):
("float64", (d, i)), ("float64", (d, i)),
("float64", ()), ("float64", ()),
("int64", ()), ("int64", ()),
]: ],
mv = MakeVector(dtype=dtype)(*inputs) )
assert mv.dtype == dtype def test_make_vector(self, dtype, inputs):
f = function([b, i, d], mv, on_unused_input="ignore") b, i, d = self.b, self.i, self.d
f(val[b], val[i], val[d])
val = {b: 2, i: -3, d: 0.7}
s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore") mv = MakeVector(dtype=dtype)(*inputs)
gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore") assert mv.dtype == dtype
gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore") f = function([b, i, d], mv, on_unused_input="ignore")
f(val[b], val[i], val[d])
g = function([b, i, d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d]) s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
if dtype in int_dtypes: gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
# The gradient should be 0 gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")
utt.assert_allclose(g_val, 0)
else: g = function([b, i, d], [gb, gi, gd])
for var, grval in zip((b, i, d), g_val): g_val = g(val[b], val[i], val[d])
float_inputs = []
if var.dtype in int_dtypes: if dtype in int_dtypes:
pass # The gradient should be 0
# Currently we don't do any checks on these variables utt.assert_allclose(g_val, 0)
# verify_grad doesn't support integer inputs yet else:
# however, the gradient on them is *not* defined to for var, grval in zip((b, i, d), g_val):
# be 0 float_inputs = []
elif var not in inputs: if var.dtype in int_dtypes:
assert grval == 0 pass
else: # Currently we don't do any checks on these variables
float_inputs.append(var) # verify_grad doesn't support integer inputs yet
# however, the gradient on them is *not* defined to
# Build a function that takes float_inputs, use fix values for the # be 0
# other inputs, and returns the MakeVector. Use it for verify_grad. elif var not in inputs:
if float_inputs: assert grval == 0
else:
def fun(*fl_inputs): float_inputs.append(var)
f_inputs = []
for var in f_inputs: # Build a function that takes float_inputs, use fix values for the
if var in fl_inputs: # other inputs, and returns the MakeVector. Use it for verify_grad.
# use symbolic variable if float_inputs:
f_inputs.append(var)
else: def fun(*fl_inputs):
# use constant value f_inputs = []
f_inputs.append(val[var]) for var in f_inputs:
return MakeVector(dtype=dtype)(*f_inputs) if var in fl_inputs:
# use symbolic variable
utt.verify_grad(fun, [val[ri] for ri in float_inputs]) f_inputs.append(var)
else:
# should fail # use constant value
for (dtype, inputs) in [ f_inputs.append(val[var])
return MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, i)), ("int8", (b, i)),
("int8", (i, b)), ("int8", (i, b)),
("int8", (b, d)), ("int8", (b, d)),
...@@ -2650,12 +2584,11 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -2650,12 +2584,11 @@ class TestMakeVector(utt.InferShapeTester):
("int32", (d, i)), ("int32", (d, i)),
("int32", (i, d)), ("int32", (i, d)),
("float32", (i, d)), ("float32", (i, d)),
]: ],
try: )
MakeVector(dtype=dtype)(*inputs) def test_make_vector_fail(self, dtype, inputs):
raise Exception("Aesara should have raised an error") with pytest.raises(AssertionError):
except AssertionError: MakeVector(dtype=dtype)(*inputs)
pass
def test_infer_shape(self): def test_infer_shape(self):
adscal = dscalar() adscal = dscalar()
...@@ -2824,8 +2757,9 @@ def test_local_join_make_vector(): ...@@ -2824,8 +2757,9 @@ def test_local_join_make_vector():
assert check_stack_trace(f, ops_to_check="all") assert check_stack_trace(f, ops_to_check="all")
def test_local_tensor_scalar_tensor(): @pytest.mark.parametrize(
dtypes = [ "dtype",
[
"int8", "int8",
"int16", "int16",
"int32", "int32",
...@@ -2838,25 +2772,24 @@ def test_local_tensor_scalar_tensor(): ...@@ -2838,25 +2772,24 @@ def test_local_tensor_scalar_tensor():
"float64", "float64",
"complex64", "complex64",
"complex128", "complex128",
] ],
)
for dtype in dtypes: def test_local_tensor_scalar_tensor(dtype):
t_type = TensorType(dtype=dtype, broadcastable=()) t_type = TensorType(dtype=dtype, broadcastable=())
t = t_type() t = t_type()
s = aet.scalar_from_tensor(t) s = aet.scalar_from_tensor(t)
t2 = aet.tensor_from_scalar(s) t2 = aet.tensor_from_scalar(s)
f = function([t], t2, mode=mode_opt) f = function([t], t2, mode=mode_opt)
e = f.maker.fgraph.toposort() e = f.maker.fgraph.toposort()
cast_nodes = [ assert not any(
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) [n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
] )
assert len(cast_nodes) == 0
f(0)
def test_local_scalar_tensor_scalar(): @pytest.mark.parametrize(
dtypes = [ "dtype",
[
"int8", "int8",
"int16", "int16",
"int32", "int32",
...@@ -2869,21 +2802,19 @@ def test_local_scalar_tensor_scalar(): ...@@ -2869,21 +2802,19 @@ def test_local_scalar_tensor_scalar():
"float64", "float64",
"complex64", "complex64",
"complex128", "complex128",
] ],
)
for dtype in dtypes: def test_local_scalar_tensor_scalar(dtype):
s_type = aes.Scalar(dtype=dtype) s_type = aes.Scalar(dtype=dtype)
s = s_type() s = s_type()
t = aet.tensor_from_scalar(s) t = aet.tensor_from_scalar(s)
s2 = aet.scalar_from_tensor(t) s2 = aet.scalar_from_tensor(t)
f = function([s], s2, mode=mode_opt) f = function([s], s2, mode=mode_opt)
e = f.maker.fgraph.toposort() e = f.maker.fgraph.toposort()
cast_nodes = [ assert not any(
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) [n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
] )
assert len(cast_nodes) == 0
f(0)
def test_local_useless_split(): def test_local_useless_split():
...@@ -2909,25 +2840,23 @@ def test_local_useless_split(): ...@@ -2909,25 +2840,23 @@ def test_local_useless_split():
assert check_stack_trace(f_nonopt, ops_to_check="all") assert check_stack_trace(f_nonopt, ops_to_check="all")
def test_local_flatten_lift(): @pytest.mark.parametrize("i", list(range(1, 4)))
for i in range(1, 4): def test_local_flatten_lift(i):
x = tensor4() x = tensor4()
out = aet.flatten(exp(x), i) out = aet.flatten(exp(x), i)
assert out.ndim == i assert out.ndim == i
mode = get_default_mode() mode = get_default_mode()
mode = mode.including("local_reshape_lift") mode = mode.including("local_reshape_lift")
f = function([x], out, mode=mode) f = function([x], out, mode=mode)
x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX) x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX)
out_np = f(x_np) out_np = f(x_np)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),) shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),)
assert shape_out_np == out_np.shape assert shape_out_np == out_np.shape
reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)] reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)]
assert len(reshape_nodes) == 1 and aet.is_flat( assert len(reshape_nodes) == 1 and aet.is_flat(reshape_nodes[0].outputs[0], ndim=i)
reshape_nodes[0].outputs[0], ndim=i assert isinstance(topo[-1].op, Elemwise)
)
assert isinstance(topo[-1].op, Elemwise)
class TestReshape: class TestReshape:
...@@ -3118,15 +3047,16 @@ class TestShapeI(utt.InferShapeTester): ...@@ -3118,15 +3047,16 @@ class TestShapeI(utt.InferShapeTester):
super().setup_method() super().setup_method()
def test_perform(self): def test_perform(self):
rng = np.random.default_rng(utt.fetch_seed())
advec = vector() advec = vector()
advec_val = np.random.random((3)).astype(config.floatX) advec_val = rng.random((3)).astype(config.floatX)
f = function([advec], Shape_i(0)(advec)) f = function([advec], Shape_i(0)(advec))
out = f(advec_val) out = f(advec_val)
utt.assert_allclose(out, advec_val.shape[0]) utt.assert_allclose(out, advec_val.shape[0])
admat = matrix() admat = matrix()
admat_val = np.random.random((4, 3)).astype(config.floatX) admat_val = rng.random((4, 3)).astype(config.floatX)
for i in range(2): for i in range(2):
f = function([admat], Shape_i(i)(admat)) f = function([admat], Shape_i(i)(admat))
out = f(admat_val) out = f(admat_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论