提交 f64d243d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor tests.tensor.test_basic_opt tests

* Convert tests to pytest `parametrized` tests * Use `pytest.raises` * Remove unused timing code and methods
上级 c86a717b
import copy import copy
import time
import numpy as np import numpy as np
import pytest import pytest
...@@ -354,45 +353,36 @@ class TestFusion: ...@@ -354,45 +353,36 @@ class TestFusion:
_shared = staticmethod(shared) _shared = staticmethod(shared)
topo_exclude = () topo_exclude = ()
def do(self, mode, shared_fn, shp, nb_repeat=1, assert_len_topo=True, slice=None): def my_init(dtype="float64", num=0):
""" return np.zeros((5, 5), dtype=dtype) + num
param shared_fn: if None, will use function
verify that the elemwise fusion work
Test with and without DimShuffle
"""
# TODO: disable the canonizer?
def my_init(shp, dtype="float64", num=0):
ret = np.zeros(shp, dtype=dtype) + num
return ret
fw, fx, fy, fz = [ fw, fx, fy, fz = [
tensor(dtype="float32", broadcastable=[False] * len(shp), name=n) tensor(dtype="float32", broadcastable=[False] * 2, name=n) for n in "wxyz"
for n in "wxyz"
] ]
dw, dx, dy, dz = [ dw, dx, dy, dz = [
tensor(dtype="float64", broadcastable=[False] * len(shp), name=n) tensor(dtype="float64", broadcastable=[False] * 2, name=n) for n in "wxyz"
for n in "wxyz"
] ]
ix, iy, iz = [ ix, iy, iz = [
tensor(dtype="int32", broadcastable=[False] * len(shp), name=n) tensor(dtype="int32", broadcastable=[False] * 2, name=n) for n in "xyz"
for n in "xyz"
] ]
fv = fvector("v") fv = fvector("v")
fs = fscalar("s") fs = fscalar("s")
fwv = my_init("float32", 1)
fwv = my_init(shp, "float32", 1) fxv = my_init("float32", 2)
fxv = my_init(shp, "float32", 2) fyv = my_init("float32", 3)
fyv = my_init(shp, "float32", 3) fzv = my_init("float32", 4)
fzv = my_init(shp, "float32", 4) fvv = _asarray(np.random.random(5), dtype="float32")
fvv = _asarray(np.random.random((shp[0])), dtype="float32")
fsv = np.asarray(np.random.random(), dtype="float32") fsv = np.asarray(np.random.random(), dtype="float32")
dwv = my_init(shp, "float64", 5) dwv = my_init("float64", 5)
ixv = _asarray(my_init(shp, num=60), dtype="int32") ixv = _asarray(my_init(num=60), dtype="int32")
iyv = _asarray(my_init(shp, num=70), dtype="int32") iyv = _asarray(my_init(num=70), dtype="int32")
izv = _asarray(my_init(shp, num=70), dtype="int32") izv = _asarray(my_init(num=70), dtype="int32")
fwx = fw + fx fwx = fw + fx
ftanx = tan(fx) ftanx = tan(fx)
cases = [
@pytest.mark.parametrize(
"case",
[
( (
fx + fy + fz, fx + fy + fz,
(fx, fy, fz), (fx, fy, fz),
...@@ -991,47 +981,40 @@ class TestFusion: ...@@ -991,47 +981,40 @@ class TestFusion:
fxv * np.sin(fsv), fxv * np.sin(fsv),
"float32", "float32",
), ),
] ],
if slice: )
cases = cases[slice] def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
times = np.zeros(len(cases)) """Verify that `Elemwise` fusion works."""
fail1 = []
fail2 = [] g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
fail3 = []
fail4 = []
for (
id,
[g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype],
) in enumerate(cases):
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
if shared_fn is None: if self._shared is None:
f = function(list(sym_inputs), g, mode=mode) f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat): for x in range(nb_repeat):
out = f(*val_inputs) out = f(*val_inputs)
t1 = time.time()
else: else:
out = shared_fn(np.zeros(shp, dtype=out_dtype), "out") out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
assert out.dtype == g.dtype assert out.dtype == g.dtype
f = function(sym_inputs, [], updates=[(out, g)], mode=mode) f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
t0 = time.time()
for x in range(nb_repeat): for x in range(nb_repeat):
f(*val_inputs) f(*val_inputs)
t1 = time.time()
out = out.get_value() out = out.get_value()
times[id] = t1 - t0
atol = 1e-8 atol = 1e-8
if out_dtype == "float32": if out_dtype == "float32":
atol = 1e-6 atol = 1e-6
if not np.allclose(out, answer * nb_repeat, atol=atol):
fail1.append(id) assert np.allclose(out, answer * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
if assert_len_topo: if assert_len_topo:
if not len(topo_) == nb_elemwise:
fail3.append((id, topo_, nb_elemwise)) assert len(topo_) == nb_elemwise
if nb_elemwise == 1: if nb_elemwise == 1:
# if no variable appears multiple times in the # if no variable appears multiple times in the
# input of g, # input of g,
...@@ -1043,16 +1026,7 @@ class TestFusion: ...@@ -1043,16 +1026,7 @@ class TestFusion:
) )
assert expected_len_sym_inputs == len(sym_inputs) assert expected_len_sym_inputs == len(sym_inputs)
if not out_dtype == out.dtype: assert out_dtype == out.dtype
fail4.append((id, out_dtype, out.dtype))
assert len(fail1 + fail2 + fail3 + fail4) == 0
return times
def test_elemwise_fusion(self):
shp = (5, 5)
self.do(self.mode, self._shared, shp)
def test_fusion_35_inputs(self): def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
...@@ -1144,78 +1118,6 @@ class TestFusion: ...@@ -1144,78 +1118,6 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
) )
def speed_fusion(self, s=None):
"""
param type s: a slice object
param s: a slice to apply to the case to execute. If None, exec all case.
"""
shp = (3000, 3000)
shp = (1000, 1000)
nb_repeat = 50
# linker=CLinker
# linker=OpWiseCLinker
mode1 = copy.copy(self.mode)
mode1._optimizer = mode1._optimizer.including("local_elemwise_fusion")
# TODO:clinker is much faster... but use to much memory
# Possible cause: as their is do deletion of intermediate value when we don't keep the fct.
# More plausible cause: we keep a link to the output data?
# Follow up. Clinker do the same... second cause?
mode2 = copy.copy(self.mode)
mode2._optimizer = mode2._optimizer.excluding("local_elemwise_fusion")
print("test with linker", str(mode1.linker))
times1 = self.do(
mode1,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
times2 = self.do(
mode2,
self._shared,
shp,
nb_repeat=nb_repeat,
assert_len_topo=False,
slice=s,
)
print("times1 with local_elemwise_fusion")
print(times1, times1.min(), times1.max(), times1.sum())
print("times2 without local_elemwise_fusion")
print(times2, times2.min(), times2.max(), times2.sum())
d = times2 / times1
print("times2/times1")
print(d)
print(
"min",
d.min(),
"argmin",
d.argmin(),
"max",
d.max(),
"mean",
d.mean(),
"std",
d.std(),
)
def speed_log_exp(self):
s = slice(31, 36)
print(
"time",
self.do(
self.mode,
self._shared,
shp=(1000, 1000),
assert_len_topo=False,
slice=s,
nb_repeat=100,
),
)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler") @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self): def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
...@@ -2342,19 +2244,24 @@ class TestLocalUselessSwitch: ...@@ -2342,19 +2244,24 @@ class TestLocalUselessSwitch:
def setup_method(self): def setup_method(self):
self.mode = mode_opt.excluding("constant_folding") self.mode = mode_opt.excluding("constant_folding")
@pytest.mark.parametrize(
"dtype1",
["int32", "int64"],
)
@pytest.mark.parametrize(
"dtype2",
["int32", "int64"],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cond", "cond",
[0, 1, np.array([True])], [0, 1, np.array([True])],
) )
def test_const(self, cond): def test_const(self, dtype1, dtype2, cond):
for dtype1 in ["int32", "int64"]:
for dtype2 in ["int32", "int64"]:
x = matrix("x", dtype=dtype1) x = matrix("x", dtype=dtype1)
y = matrix("y", dtype=dtype2) y = matrix("y", dtype=dtype2)
z = aet.switch(cond, x, y) z = aet.switch(cond, x, y)
f = function([x, y], z, mode=self.mode) f = function([x, y], z, mode=self.mode)
assert ( assert not any(
len(
[ [
node.op node.op
for node in f.maker.fgraph.toposort() for node in f.maker.fgraph.toposort()
...@@ -2364,15 +2271,16 @@ class TestLocalUselessSwitch: ...@@ -2364,15 +2271,16 @@ class TestLocalUselessSwitch:
) )
] ]
) )
== 0
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(cond, vx, vy) np_res = np.where(cond, vx, vy)
assert np.array_equal(f(vx, vy), np_res) assert np.array_equal(f(vx, vy), np_res)
def test_left_is_right(self): @pytest.mark.parametrize(
for dtype1 in ["int32", "int64"]: "dtype1",
["int32", "int64"],
)
def test_left_is_right(self, dtype1):
x = matrix("x", dtype=dtype1) x = matrix("x", dtype=dtype1)
varc = matrix("varc", dtype=dtype1) varc = matrix("varc", dtype=dtype1)
z1 = aet.switch(1, x, x) z1 = aet.switch(1, x, x)
...@@ -2400,8 +2308,11 @@ class TestLocalUselessSwitch: ...@@ -2400,8 +2308,11 @@ class TestLocalUselessSwitch:
assert np.array_equal(f0(vx), vx) assert np.array_equal(f0(vx), vx)
assert np.array_equal(f2(vx, vc), vx) assert np.array_equal(f2(vx, vc), vx)
def test_shape_le_0(self): @pytest.mark.parametrize(
for dtype1 in ["float32", "float64"]: "dtype1",
["float32", "float64"],
)
def test_shape_le_0(self, dtype1):
x = matrix("x", dtype=dtype1) x = matrix("x", dtype=dtype1)
z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0]) z0 = aet.switch(le(x.shape[0], 0), 0, x.shape[0])
f0 = function([x], z0, mode=self.mode) f0 = function([x], z0, mode=self.mode)
...@@ -2489,13 +2400,9 @@ class TestLocalUselessSwitch: ...@@ -2489,13 +2400,9 @@ class TestLocalUselessSwitch:
class TestLocalMergeSwitchSameCond: class TestLocalMergeSwitchSameCond:
def test_elemwise(self): @pytest.mark.parametrize(
# float Ops "op",
mats = matrices("cabxy") [
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
for op in (
add, add,
sub, sub,
mul, mul,
...@@ -2511,30 +2418,53 @@ class TestLocalMergeSwitchSameCond: ...@@ -2511,30 +2418,53 @@ class TestLocalMergeSwitchSameCond:
eq, eq,
neq, neq,
aet_pow, aet_pow,
): ],
)
def test_elemwise_float_ops(self, op):
# float Ops
mats = matrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
g = optimize(FunctionGraph(mats, [op(s1, s2)])) g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert str(g).count("Switch") == 1
@pytest.mark.parametrize(
"op",
[
bitwise_and,
bitwise_or,
bitwise_xor,
],
)
def test_elemwise_int_ops(self, op):
# integer Ops # integer Ops
mats = imatrices("cabxy") mats = imatrices("cabxy")
c, a, b, x, y = mats c, a, b, x, y = mats
s1 = aet.switch(c, a, b) s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y) s2 = aet.switch(c, x, y)
for op in (
bitwise_and,
bitwise_or,
bitwise_xor,
):
g = optimize(FunctionGraph(mats, [op(s1, s2)])) g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert str(g).count("Switch") == 1
@pytest.mark.parametrize("op", [add, mul])
def test_elemwise_multi_inputs(self, op):
# add/mul with more than two inputs # add/mul with more than two inputs
mats = imatrices("cabxy")
c, a, b, x, y = mats
s1 = aet.switch(c, a, b)
s2 = aet.switch(c, x, y)
u, v = matrices("uv") u, v = matrices("uv")
s3 = aet.switch(c, u, v) s3 = aet.switch(c, u, v)
for op in (add, mul):
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1 assert str(g).count("Switch") == 1
class TestLocalOptAlloc: class TestLocalOptAlloc:
"""
TODO FIXME: These tests are incomplete; they need to `assert` something.
"""
dtype = "float32" dtype = "float32"
def test_sum_upcast(self): def test_sum_upcast(self):
...@@ -2571,20 +2501,17 @@ class TestLocalOptAllocF16(TestLocalOptAlloc): ...@@ -2571,20 +2501,17 @@ class TestLocalOptAllocF16(TestLocalOptAlloc):
class TestMakeVector(utt.InferShapeTester): class TestMakeVector(utt.InferShapeTester):
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method()
def test_make_vector(self):
b = bscalar() b = bscalar()
i = iscalar() i = iscalar()
d = dscalar() d = dscalar()
# TODO: draw random values instead. Not really important. def setup_method(self):
val = {b: 2, i: -3, d: 0.7} self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method()
# Should work @pytest.mark.parametrize(
for (dtype, inputs) in [ "dtype, inputs",
[
("int8", (b, b)), ("int8", (b, b)),
("int32", (i, b)), ("int32", (i, b)),
("int32", (b, i)), ("int32", (b, i)),
...@@ -2593,7 +2520,13 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -2593,7 +2520,13 @@ class TestMakeVector(utt.InferShapeTester):
("float64", (d, i)), ("float64", (d, i)),
("float64", ()), ("float64", ()),
("int64", ()), ("int64", ()),
]: ],
)
def test_make_vector(self, dtype, inputs):
b, i, d = self.b, self.i, self.d
val = {b: 2, i: -3, d: 0.7}
mv = MakeVector(dtype=dtype)(*inputs) mv = MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype assert mv.dtype == dtype
f = function([b, i, d], mv, on_unused_input="ignore") f = function([b, i, d], mv, on_unused_input="ignore")
...@@ -2641,8 +2574,9 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -2641,8 +2574,9 @@ class TestMakeVector(utt.InferShapeTester):
utt.verify_grad(fun, [val[ri] for ri in float_inputs]) utt.verify_grad(fun, [val[ri] for ri in float_inputs])
# should fail @pytest.mark.parametrize(
for (dtype, inputs) in [ "dtype, inputs",
[
("int8", (b, i)), ("int8", (b, i)),
("int8", (i, b)), ("int8", (i, b)),
("int8", (b, d)), ("int8", (b, d)),
...@@ -2650,12 +2584,11 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -2650,12 +2584,11 @@ class TestMakeVector(utt.InferShapeTester):
("int32", (d, i)), ("int32", (d, i)),
("int32", (i, d)), ("int32", (i, d)),
("float32", (i, d)), ("float32", (i, d)),
]: ],
try: )
def test_make_vector_fail(self, dtype, inputs):
with pytest.raises(AssertionError):
MakeVector(dtype=dtype)(*inputs) MakeVector(dtype=dtype)(*inputs)
raise Exception("Aesara should have raised an error")
except AssertionError:
pass
def test_infer_shape(self): def test_infer_shape(self):
adscal = dscalar() adscal = dscalar()
...@@ -2824,8 +2757,9 @@ def test_local_join_make_vector(): ...@@ -2824,8 +2757,9 @@ def test_local_join_make_vector():
assert check_stack_trace(f, ops_to_check="all") assert check_stack_trace(f, ops_to_check="all")
def test_local_tensor_scalar_tensor(): @pytest.mark.parametrize(
dtypes = [ "dtype",
[
"int8", "int8",
"int16", "int16",
"int32", "int32",
...@@ -2838,9 +2772,9 @@ def test_local_tensor_scalar_tensor(): ...@@ -2838,9 +2772,9 @@ def test_local_tensor_scalar_tensor():
"float64", "float64",
"complex64", "complex64",
"complex128", "complex128",
] ],
)
for dtype in dtypes: def test_local_tensor_scalar_tensor(dtype):
t_type = TensorType(dtype=dtype, broadcastable=()) t_type = TensorType(dtype=dtype, broadcastable=())
t = t_type() t = t_type()
s = aet.scalar_from_tensor(t) s = aet.scalar_from_tensor(t)
...@@ -2848,15 +2782,14 @@ def test_local_tensor_scalar_tensor(): ...@@ -2848,15 +2782,14 @@ def test_local_tensor_scalar_tensor():
f = function([t], t2, mode=mode_opt) f = function([t], t2, mode=mode_opt)
e = f.maker.fgraph.toposort() e = f.maker.fgraph.toposort()
cast_nodes = [ assert not any(
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) [n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
] )
assert len(cast_nodes) == 0
f(0)
def test_local_scalar_tensor_scalar(): @pytest.mark.parametrize(
dtypes = [ "dtype",
[
"int8", "int8",
"int16", "int16",
"int32", "int32",
...@@ -2869,9 +2802,9 @@ def test_local_scalar_tensor_scalar(): ...@@ -2869,9 +2802,9 @@ def test_local_scalar_tensor_scalar():
"float64", "float64",
"complex64", "complex64",
"complex128", "complex128",
] ],
)
for dtype in dtypes: def test_local_scalar_tensor_scalar(dtype):
s_type = aes.Scalar(dtype=dtype) s_type = aes.Scalar(dtype=dtype)
s = s_type() s = s_type()
t = aet.tensor_from_scalar(s) t = aet.tensor_from_scalar(s)
...@@ -2879,11 +2812,9 @@ def test_local_scalar_tensor_scalar(): ...@@ -2879,11 +2812,9 @@ def test_local_scalar_tensor_scalar():
f = function([s], s2, mode=mode_opt) f = function([s], s2, mode=mode_opt)
e = f.maker.fgraph.toposort() e = f.maker.fgraph.toposort()
cast_nodes = [ assert not any(
n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) [n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor))]
] )
assert len(cast_nodes) == 0
f(0)
def test_local_useless_split(): def test_local_useless_split():
...@@ -2909,8 +2840,8 @@ def test_local_useless_split(): ...@@ -2909,8 +2840,8 @@ def test_local_useless_split():
assert check_stack_trace(f_nonopt, ops_to_check="all") assert check_stack_trace(f_nonopt, ops_to_check="all")
def test_local_flatten_lift(): @pytest.mark.parametrize("i", list(range(1, 4)))
for i in range(1, 4): def test_local_flatten_lift(i):
x = tensor4() x = tensor4()
out = aet.flatten(exp(x), i) out = aet.flatten(exp(x), i)
assert out.ndim == i assert out.ndim == i
...@@ -2924,9 +2855,7 @@ def test_local_flatten_lift(): ...@@ -2924,9 +2855,7 @@ def test_local_flatten_lift():
assert shape_out_np == out_np.shape assert shape_out_np == out_np.shape
reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)] reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)]
assert len(reshape_nodes) == 1 and aet.is_flat( assert len(reshape_nodes) == 1 and aet.is_flat(reshape_nodes[0].outputs[0], ndim=i)
reshape_nodes[0].outputs[0], ndim=i
)
assert isinstance(topo[-1].op, Elemwise) assert isinstance(topo[-1].op, Elemwise)
...@@ -3118,15 +3047,16 @@ class TestShapeI(utt.InferShapeTester): ...@@ -3118,15 +3047,16 @@ class TestShapeI(utt.InferShapeTester):
super().setup_method() super().setup_method()
def test_perform(self): def test_perform(self):
rng = np.random.default_rng(utt.fetch_seed())
advec = vector() advec = vector()
advec_val = np.random.random((3)).astype(config.floatX) advec_val = rng.random((3)).astype(config.floatX)
f = function([advec], Shape_i(0)(advec)) f = function([advec], Shape_i(0)(advec))
out = f(advec_val) out = f(advec_val)
utt.assert_allclose(out, advec_val.shape[0]) utt.assert_allclose(out, advec_val.shape[0])
admat = matrix() admat = matrix()
admat_val = np.random.random((4, 3)).astype(config.floatX) admat_val = rng.random((4, 3)).astype(config.floatX)
for i in range(2): for i in range(2):
f = function([admat], Shape_i(i)(admat)) f = function([admat], Shape_i(i)(admat))
out = f(admat_val) out = f(admat_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论