提交 d14b15b7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix test_Broadcast to respect the C/perform division correctly.

上级 d3bfaae3
...@@ -166,10 +166,10 @@ class test_Broadcast(unittest.TestCase): ...@@ -166,10 +166,10 @@ class test_Broadcast(unittest.TestCase):
linkers = [gof.PerformLinker, gof.CLinker] linkers = [gof.PerformLinker, gof.CLinker]
def rand_val(self, shp): def rand_val(self, shp):
return numpy.asarray(numpy.random.rand(*shp)) return numpy.asarray(numpy.random.rand(*shp), dtype='float32')
def rand_cval(self, shp): def rand_cval(self, shp):
return numpy.asarray(numpy.random.rand(*shp)) return numpy.asarray(numpy.random.rand(*shp), dtype='float32')
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
...@@ -189,8 +189,8 @@ class test_Broadcast(unittest.TestCase): ...@@ -189,8 +189,8 @@ class test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = type('float64', [(entry == 1) for entry in xsh])('x') x = type('float32', [(entry == 1) for entry in xsh])('x')
y = type('float64', [(entry == 1) for entry in ysh])('y') y = type('float32', [(entry == 1) for entry in ysh])('y')
e = op(scalar.add)(x, y) e = op(scalar.add)(x, y)
f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function() f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function()
xv = rand_val(xsh) xv = rand_val(xsh)
...@@ -202,8 +202,8 @@ class test_Broadcast(unittest.TestCase): ...@@ -202,8 +202,8 @@ class test_Broadcast(unittest.TestCase):
# test Elemwise.infer_shape # test Elemwise.infer_shape
# the Shape op don't implement c_code! # the Shape op don't implement c_code!
if isinstance(linker, gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = type('float64', [(entry == 1) for entry in xsh])('x') x = type('float32', [(entry == 1) for entry in xsh])('x')
y = type('float64', [(entry == 1) for entry in ysh])('y') y = type('float32', [(entry == 1) for entry in ysh])('y')
e = op(scalar.add)(x, y) e = op(scalar.add)(x, y)
f = copy(linker).accept(FunctionGraph( f = copy(linker).accept(FunctionGraph(
[x, y], [e.shape])).make_function() [x, y], [e.shape])).make_function()
...@@ -218,8 +218,8 @@ class test_Broadcast(unittest.TestCase): ...@@ -218,8 +218,8 @@ class test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = type('float64', [(entry == 1) for entry in xsh])('x') x = type('float32', [(entry == 1) for entry in xsh])('x')
y = type('float64', [(entry == 1) for entry in ysh])('y') y = type('float32', [(entry == 1) for entry in ysh])('y')
e = op(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y) e = op(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y)
f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function() f = copy(linker).accept(FunctionGraph([x, y], [e])).make_function()
xv = rand_val(xsh) xv = rand_val(xsh)
...@@ -232,8 +232,8 @@ class test_Broadcast(unittest.TestCase): ...@@ -232,8 +232,8 @@ class test_Broadcast(unittest.TestCase):
# test Elemwise.infer_shape # test Elemwise.infer_shape
# the Shape op don't implement c_code! # the Shape op don't implement c_code!
if isinstance(linker, gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = type('float64', [(entry == 1) for entry in xsh])('x') x = type('float32', [(entry == 1) for entry in xsh])('x')
y = type('float64', [(entry == 1) for entry in ysh])('y') y = type('float32', [(entry == 1) for entry in ysh])('y')
e = op(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y) e = op(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y)
f = copy(linker).accept(FunctionGraph( f = copy(linker).accept(FunctionGraph(
[x, y], [e.shape])).make_function() [x, y], [e.shape])).make_function()
...@@ -267,13 +267,15 @@ class test_Broadcast(unittest.TestCase): ...@@ -267,13 +267,15 @@ class test_Broadcast(unittest.TestCase):
def test_fill(self): def test_fill(self):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
x = self.ctype('float64', [0, 0])('x') for linker, op, t, rval in zip(self.linkers, [self.op, self.cop],
y = self.ctype('float64', [1, 1])('y') [self.type, self.ctype],
for linker, op in zip(self.linkers, [self.op, self.cop]): [self.rand_val, self.rand_cval]):
x = t('float32', [0, 0])('x')
y = t('float32', [1, 1])('y')
e = op(scalar.Second(scalar.transfer_type(0)), {0: 0})(x, y) e = op(scalar.Second(scalar.transfer_type(0)), {0: 0})(x, y)
f = linker().accept(FunctionGraph([x, y], [e])).make_function() f = linker().accept(FunctionGraph([x, y], [e])).make_function()
xv = self.rand_cval((5, 5)) xv = rval((5, 5))
yv = self.rand_cval((1, 1)) yv = rval((1, 1))
f(xv, yv) f(xv, yv)
assert (xv == yv).all() assert (xv == yv).all()
...@@ -292,24 +294,28 @@ class test_Broadcast(unittest.TestCase): ...@@ -292,24 +294,28 @@ class test_Broadcast(unittest.TestCase):
def test_weird_strides(self): def test_weird_strides(self):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
x = self.ctype('float64', [0, 0, 0, 0, 0])('x') for linker, op, t, rval in zip(self.linkers, [self.op, self.cop],
y = self.ctype('float64', [0, 0, 0, 0, 0])('y') [self.type, self.ctype],
for linker, op in zip(self.linkers, [self.op, self.cop]): [self.rand_val, self.rand_cval]):
x = t('float32', [0, 0, 0, 0, 0])('x')
y = t('float32', [0, 0, 0, 0, 0])('y')
e = op(scalar.add)(x, y) e = op(scalar.add)(x, y)
f = linker().accept(FunctionGraph([x, y], [e])).make_function() f = linker().accept(FunctionGraph([x, y], [e])).make_function()
xv = self.rand_cval((2, 2, 2, 2, 2)) xv = rval((2, 2, 2, 2, 2))
yv = self.rand_cval((2, 2, 2, 2, 2)).transpose(4, 0, 3, 1, 2) yv = rval((2, 2, 2, 2, 2)).transpose(4, 0, 3, 1, 2)
zv = xv + yv zv = xv + yv
assert (f(xv, yv) == zv).all() assert (f(xv, yv) == zv).all()
def test_same_inputs(self): def test_same_inputs(self):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
x = self.ctype('float64', [0, 0])('x') for linker, op, t, rval in zip(self.linkers, [self.op, self.cop],
for linker, op in zip(self.linkers, [self.op, self.cop]): [self.type, self.ctype],
[self.rand_val, self.rand_cval]):
x = t('float32', [0, 0])('x')
e = op(scalar.add)(x, x) e = op(scalar.add)(x, x)
f = linker().accept(FunctionGraph([x], [e])).make_function() f = linker().accept(FunctionGraph([x], [e])).make_function()
xv = self.rand_cval((2, 2)) xv = rval((2, 2))
zv = xv + xv zv = xv + xv
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论