提交 d3d56a6c authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Clean up tests.scalar.test_basic

- Remove unused tests and unnecessary helper functions - Rename camel-case tests - Use pytest.raises
上级 8e8eda2b
...@@ -69,58 +69,21 @@ from theano.scalar.basic import ( ...@@ -69,58 +69,21 @@ from theano.scalar.basic import (
) )
def inputs(): def test_mul_add_div_proxy():
return floats("xyz") x, y, z = floats("xyz")
e = mul(add(x, y), div_proxy(x, y))
g = FunctionGraph([x, y], [e])
class TestScalarOps: fn = gof.DualLinker().accept(g).make_function()
def test_straightforward(self): assert fn(1.0, 2.0) == 1.5
x, y, z = inputs()
e = mul(add(x, y), div_proxy(x, y))
g = FunctionGraph([x, y], [e])
fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 1.5
# This test is moved to tests.tensor.test_basic.py:test_mod
# We move it their as under ubuntu the c_extract call of theano.scalar
# call PyInt_check and it fail under some os. If work in other case.
# As we use theano.scalar normally, but we use theano.tensor.scalar
# that is not important. Also this make the theano fct fail at call time
# so this is not a silent bug.
# --> This is why it is purposely named 'tes_mod' instead of 'test_mod'.
def tes_mod(self):
# We add this test as not all language and C implementation give the same
# sign to the result. This check that the c_code of `Mod` is implemented
# as Python. That is what we want.
x, y = ints("xy")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x % y])).make_function()
for a, b in (
(0, 1),
(1, 1),
(0, -1),
(1, -1),
(-1, -1),
(1, 2),
(-1, 2),
(1, -2),
(-1, -2),
(5, 3),
(-5, 3),
(5, -3),
(-5, -3),
):
assert fn(a, b) == a % b, (a,)
def has_f16(comp):
if any(v.type == float16 for v in comp.fgraph.variables):
return True
return False
class TestComposite: class TestComposite:
def test_composite_clone_float32(self): def test_composite_clone_float32(self):
def has_f16(comp):
if any(v.type == float16 for v in comp.fgraph.variables):
return True
return False
w = int8() w = int8()
x = float16() x = float16()
y = float32() y = float32()
...@@ -153,7 +116,7 @@ class TestComposite: ...@@ -153,7 +116,7 @@ class TestComposite:
assert not has_f16(nc) assert not has_f16(nc)
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = floats("xyz")
e = mul(add(x, y), div_proxy(x, y)) e = mul(add(x, y), div_proxy(x, y))
C = Composite([x, y], [e]) C = Composite([x, y], [e])
c = C.make_node(x, y) c = C.make_node(x, y)
...@@ -164,7 +127,7 @@ class TestComposite: ...@@ -164,7 +127,7 @@ class TestComposite:
def test_flatten(self): def test_flatten(self):
# Test that we flatten multiple Composite. # Test that we flatten multiple Composite.
x, y, z = inputs() x, y, z = floats("xyz")
C = Composite([x, y], [x + y]) C = Composite([x, y], [x + y])
CC = Composite([x, y], [C(x * y, y)]) CC = Composite([x, y], [C(x * y, y)])
assert not isinstance(CC.outputs[0].owner.op, Composite) assert not isinstance(CC.outputs[0].owner.op, Composite)
...@@ -175,7 +138,7 @@ class TestComposite: ...@@ -175,7 +138,7 @@ class TestComposite:
assert isinstance(CC.outputs[0].owner.op, Composite) assert isinstance(CC.outputs[0].owner.op, Composite)
def test_with_constants(self): def test_with_constants(self):
x, y, z = inputs() x, y, z = floats("xyz")
e = mul(add(70.0, y), div_proxy(x, y)) e = mul(add(70.0, y), div_proxy(x, y))
C = Composite([x, y], [e]) C = Composite([x, y], [e])
c = C.make_node(x, y) c = C.make_node(x, y)
...@@ -186,7 +149,7 @@ class TestComposite: ...@@ -186,7 +149,7 @@ class TestComposite:
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self): def test_many_outputs(self):
x, y, z = inputs() x, y, z = floats("xyz")
e0 = x + y + z e0 = x + y + z
e1 = x + y * z e1 = x + y * z
e2 = x / y e2 = x / y
...@@ -243,37 +206,37 @@ class TestComposite: ...@@ -243,37 +206,37 @@ class TestComposite:
class TestLogical: class TestLogical:
def test_gt(self): def test_gt(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x > y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x > y])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a > b) assert fn(a, b) == (a > b)
def test_lt(self): def test_lt(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x < y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x < y])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a < b) assert fn(a, b) == (a < b)
def test_le(self): def test_le(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x <= y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x <= y])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a <= b) assert fn(a, b) == (a <= b)
def test_ge(self): def test_ge(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x >= y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x >= y])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a >= b) assert fn(a, b) == (a >= b)
def test_eq(self): def test_eq(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [eq(x, y)])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [eq(x, y)])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a == b) assert fn(a, b) == (a == b)
def test_neq(self): def test_neq(self):
x, y, z = inputs() x, y, z = floats("xyz")
fn = gof.DualLinker().accept(FunctionGraph([x, y], [neq(x, y)])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [neq(x, y)])).make_function()
for a, b in ((3.0, 9), (3, 0.9), (3, 3)): for a, b in ((3.0, 9), (3, 0.9), (3, 3)):
assert fn(a, b) == (a != b) assert fn(a, b) == (a != b)
...@@ -424,36 +387,33 @@ class TestUpgradeToFloat: ...@@ -424,36 +387,33 @@ class TestUpgradeToFloat:
self._test_binary(binary_op, x_range, y_range) self._test_binary(binary_op, x_range, y_range)
class TestComplexMod: def test_mod_complex_fail():
# Make sure % fails on complex numbers. # Make sure % fails on complex numbers.
x = complex64()
def test_fail(self): y = int32()
x = complex64() with pytest.raises(ComplexError):
y = int32() x % y
with pytest.raises(ComplexError):
x % y
def test_div():
a = int8()
class TestDiv: b = int32()
def test_0(self): c = complex64()
a = int8() d = float64()
b = int32() f = float32()
c = complex64()
d = float64() assert isinstance((a // b).owner.op, IntDiv)
f = float32() assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((a // b).owner.op, IntDiv) assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((b // a).owner.op, IntDiv) assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((b / d).owner.op, TrueDiv) assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((b / f).owner.op, TrueDiv) assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / a).owner.op, TrueDiv) assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((d / b).owner.op, TrueDiv) assert isinstance((a / c).owner.op, TrueDiv)
assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((a / c).owner.op, TrueDiv) def test_grad_gt():
def TestGradGt():
x = float32(name="x") x = float32(name="x")
y = float32(name="y") y = float32(name="y")
z = x > y z = x > y
...@@ -461,7 +421,7 @@ def TestGradGt(): ...@@ -461,7 +421,7 @@ def TestGradGt():
assert g.eval({y: 1.0}) == 0.0 assert g.eval({y: 1.0}) == 0.0
def TestGradSwitch(): def test_grad_switch():
# This is a code snippet from the mailing list # This is a code snippet from the mailing list
# It caused an assert to be raised due to the # It caused an assert to be raised due to the
...@@ -477,7 +437,7 @@ def TestGradSwitch(): ...@@ -477,7 +437,7 @@ def TestGradSwitch():
theano.gradient.grad(l, x) theano.gradient.grad(l, x)
def TestGradIdentity(): def test_grad_identity():
# Check that the grad method of Identity correctly handles int dytpes # Check that the grad method of Identity correctly handles int dytpes
x = theano.tensor.imatrix("x") x = theano.tensor.imatrix("x")
# tensor_copy is Elemwise{Identity} # tensor_copy is Elemwise{Identity}
...@@ -486,7 +446,7 @@ def TestGradIdentity(): ...@@ -486,7 +446,7 @@ def TestGradIdentity():
theano.gradient.grad(l, x) theano.gradient.grad(l, x)
def TestGradInrange(): def test_grad_inrange():
for bound_definition in [(True, True), (False, False)]: for bound_definition in [(True, True), (False, False)]:
# Instantiate op, and then take the gradient # Instantiate op, and then take the gradient
op = InRange(*bound_definition) op = InRange(*bound_definition)
...@@ -512,7 +472,7 @@ def TestGradInrange(): ...@@ -512,7 +472,7 @@ def TestGradInrange():
utt.assert_allclose(f(7, 1, 5), [0, 0, 0]) utt.assert_allclose(f(7, 1, 5), [0, 0, 0])
def TestGradAbs(): def test_grad_abs():
a = theano.tensor.fscalar("a") a = theano.tensor.fscalar("a")
b = theano.tensor.nnet.relu(a) b = theano.tensor.nnet.relu(a)
c = theano.grad(b, a) c = theano.grad(b, a)
...@@ -523,11 +483,7 @@ def TestGradAbs(): ...@@ -523,11 +483,7 @@ def TestGradAbs():
assert ret == 0.5, ret assert ret == 0.5, ret
# Testing of Composite is done in tensor/tests/test_opt.py def test_constant():
# in test_fusion, TestCompositeCodegen
def TestConstant():
c = constant(2, name="a") c = constant(2, name="a")
assert c.name == "a" assert c.name == "a"
assert c.dtype == "int8" assert c.dtype == "int8"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论