提交 a31a572e authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge.

...@@ -169,51 +169,140 @@ class Test_pfunc(unittest.TestCase): ...@@ -169,51 +169,140 @@ class Test_pfunc(unittest.TestCase):
# ... but not to b.value ! # ... but not to b.value !
assert not (bval == b.value).all() assert not (bval == b.value).all()
def test_param_allow_downcast(self): def test_param_allow_downcast_int(self):
a = tensor.wvector('a') # int16 a = tensor.wvector('a') # int16
b = tensor.bvector('b') # int8 b = tensor.bvector('b') # int8
c = tensor.bscalar('c') # int8
f = pfunc([Param(a, allow_downcast=True), f = pfunc([Param(a, allow_downcast=True),
Param(b, allow_downcast=False)], Param(b, allow_downcast=False),
a+b) Param(c, allow_downcast=None)],
a+b+c)
# Both values are in range. Since they're not ndarrays (but lists), # Both values are in range. Since they're not ndarrays (but lists),
# they will be converted, and their value checked. # they will be converted, and their value checked.
assert numpy.all(f([3], [6]) == 9) assert numpy.all(f([3], [6], 1) == 10)
# Values are in range, but a dtype too large has explicitly been given # Values are in range, but a dtype too large has explicitly been given
# For performance reasons, no check of the data is explicitly performed # For performance reasons, no check of the data is explicitly performed
# (It might be OK to change this in the future.) # (It might be OK to change this in the future.)
self.assertRaises(TypeError, f, self.assertRaises(TypeError, f,
[3], numpy.array([6], dtype='int16')) [3], numpy.array([6], dtype='int16'), 1)
# Value too big for a, silently ignored # Value too big for a, silently ignored
assert numpy.all(f([2**20], numpy.ones(1, dtype='int8')) == 1) assert numpy.all(f([2**20], numpy.ones(1, dtype='int8'), 1) == 2)
# Value too big for b, raises TypeError # Value too big for b, raises TypeError
self.assertRaises(TypeError, f, [3], [312]) self.assertRaises(TypeError, f, [3], [312], 1)
def test_allow_input_downcast(self): # Value too big for c, raises TypeError
self.assertRaises(TypeError, f, [3], [6], 806)
def test_param_allow_downcast_floatX(self):
a = tensor.fscalar('a')
b = tensor.fscalar('b')
c = tensor.fscalar('c')
f = pfunc([Param(a, allow_downcast=True),
Param(b, allow_downcast=False),
Param(c, allow_downcast=None)],
a+b+c)
# If the values can be accurately represented, everything is OK
assert numpy.all(f(0, 0, 0) == 0)
# If allow_downcast is True, idem
assert numpy.allclose(f(0.1, 0, 0), 0.1)
# If allow_downcast is False, nope
self.assertRaises(TypeError, f, 0, 0.1, 0)
# If allow_downcast is None, it should work iff floatX=float32
if config.floatX == 'float32':
assert numpy.allclose(f(0, 0, 0.1), 0.1)
else:
self.assertRaises(TypeError, f, 0, 0, 0.1)
def test_param_allow_downcast_vector_floatX(self):
a = tensor.fvector('a')
b = tensor.fvector('b')
c = tensor.fvector('c')
f = pfunc([Param(a, allow_downcast=True),
Param(b, allow_downcast=False),
Param(c, allow_downcast=None)],
a+b+c)
# If the values can be accurately represented, everything is OK
z = [0]
assert numpy.all(f(z, z, z) == 0)
# If allow_downcast is True, idem
assert numpy.allclose(f([0.1], z, z), 0.1)
# If allow_downcast is False, nope
self.assertRaises(TypeError, f, z, [0.1], z)
# If allow_downcast is None, like False
self.assertRaises(TypeError, f, z, z, [0.1])
def test_allow_input_downcast_int(self):
a = tensor.wvector('a') # int16 a = tensor.wvector('a') # int16
b = tensor.bvector('b') # int8 b = tensor.bvector('b') # int8
c = tensor.bscalar('c') # int8
f = pfunc([a, b], a+b, allow_input_downcast=True) f = pfunc([a, b, c], a+b+c, allow_input_downcast=True)
# Value too big for a or b, silently ignored # Value too big for a, b, or c, silently ignored
assert f([2**20], [1]) == 1 assert f([2**20], [1], 0) == 1
assert f([3], [312]) == 59 assert f([3], [312], 0) == 59
assert f([3], [1], 806) == 42
g = pfunc([a, b], a+b, allow_input_downcast=False) g = pfunc([a, b, c], a+b+c, allow_input_downcast=False)
# Both values are in range. Since they're not ndarrays (but lists), # All values are in range. Since they're not ndarrays (but lists
# they will be converted, and their value checked. # or scalars), they will be converted, and their value checked.
assert numpy.all(g([3], [6]) == 9) assert numpy.all(g([3], [6], 0) == 9)
# Values are in range, but a dtype too large has explicitly been given # Values are in range, but a dtype too large has explicitly been given
# For performance reasons, no check of the data is explicitly performed # For performance reasons, no check of the data is explicitly performed
# (It might be OK to change this in the future.) # (It might be OK to change this in the future.)
self.assertRaises(TypeError, g, self.assertRaises(TypeError, g,
[3], numpy.array([6], dtype='int16')) [3], numpy.array([6], dtype='int16'), 0)
# Value too big for b, raises TypeError # Value too big for b, raises TypeError
self.assertRaises(TypeError, g, [3], [312]) self.assertRaises(TypeError, g, [3], [312], 0)
h = pfunc([a, b, c], a+b+c) # Default: allow_input_downcast=None
# Everything here should behave like with False
assert numpy.all(h([3], [6], 0) == 9)
self.assertRaises(TypeError, h,
[3], numpy.array([6], dtype='int16'), 0)
self.assertRaises(TypeError, h, [3], [312], 0)
def test_allow_downcast_floatX(self):
a = tensor.fscalar('a')
b = tensor.fvector('b')
f = pfunc([a, b], a+b, allow_input_downcast=True)
g = pfunc([a, b], a+b, allow_input_downcast=False)
h = pfunc([a, b], a+b, allow_input_downcast=None)
# If the values can be accurately represented, OK
assert numpy.all(f(0, [0]) == 0)
assert numpy.all(g(0, [0]) == 0)
assert numpy.all(h(0, [0]) == 0)
# For the vector: OK iff allow_input_downcast is True
assert numpy.allclose(f(0, [0.1]), 0.1)
self.assertRaises(TypeError, g, 0, [0.1])
self.assertRaises(TypeError, h, 0, [0.1])
# For the scalar: OK if allow_input_downcast is True,
# or None and floatX==float32
assert numpy.allclose(f(0.1, [0]), 0.1)
self.assertRaises(TypeError, g, 0.1, [0])
if config.floatX == 'float32':
assert numpy.allclose(h(0.1, [0]), 0.1)
else:
self.assertRaises(TypeError, h, 0.1, [0])
def test_update(self): def test_update(self):
"""Test update mechanism in different settings.""" """Test update mechanism in different settings."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论