提交 b55d3697 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix opt failure for the new switch opt and better test it.

上级 5628b044
......@@ -1469,6 +1469,9 @@ def local_remove_switch_const_cond(node):
return False
if out.dtype != node.outputs[0].dtype:
out = T.cast(out, node.outputs[0].dtype)
if out.type.broadcastable != node.outputs[0].type.broadcastable:
# We need to copy data to the new dimensions during execution
out = T.alloc(out, *[node.outputs[0].shape[i] for i in range(out.ndim)])
return [out]
return False
......
......@@ -2272,31 +2272,84 @@ class T_local_erfc(unittest.TestCase):
class test_local_remove_switch_const_cond(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.get_default_mode().excluding('constant_folding')
def test_const0(self):
x = theano.tensor.matrix('x', dtype='int64')
y = theano.tensor.matrix('y', dtype='int64')
z = theano.tensor.switch(0, x , y)
f = theano.function([x,y],z)
assert len([x for x in f.maker.env.toposort() if
isinstance(x,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype='int64')
vy = numpy.array([[7,8,9],[10,11,12]], dtype='int64')
assert numpy.all(f(vx,vy) == vy)
for dtype1 in ['int32', 'int64']:
for dtype2 in ['int32', 'int64']:
x = theano.tensor.matrix('x', dtype=dtype1)
y = theano.tensor.matrix('y', dtype=dtype2)
z = theano.tensor.switch(0, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype=dtype1)
vy = numpy.array([[7,8,9],[10,11,12]], dtype=dtype2)
assert numpy.all(f(vx,vy) == vy)
def test_const1(self):
x = theano.tensor.matrix('x', dtype='int64')
for dtype1 in ['int32', 'int64']:
for dtype2 in ['int32', 'int64']:
x = theano.tensor.matrix('x', dtype=dtype1)
y = theano.tensor.matrix('y', dtype=dtype2)
z = theano.tensor.switch(1, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype=dtype1)
vy = numpy.array([[7,8,9],[10,11,12]], dtype=dtype2)
assert numpy.all(f(vx,vy) == vx)
def test_broadcast1(self):
#test switch(cst, matrix, row)
x = theano.tensor.matrix('x', dtype='int32')
y = theano.tensor.vector('y', dtype='int64')
z = theano.tensor.switch(1, x, y)
f = theano.function([x,y], z, mode=self.mode)
#theano.printing.debugprint(f)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) and
not isinstance(node.op.scalar_op,theano.scalar.basic.Cast)]) == 0
vx = numpy.array([[1, 2, 3],[ 4, 5, 6]], dtype='int32')
vy = numpy.array([10,11,12], dtype='int64')
assert numpy.all(f(vx,vy) == vx)
z = theano.tensor.switch(0, x, y)
f = theano.function([x,y], z, mode=self.mode)
#theano.printing.debugprint(f)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1, 2, 3],[ 4, 5, 6]], dtype='int32')
vy = numpy.array([10,11,12], dtype='int64')
assert numpy.all(f(vx,vy) == vy)
def test_broadcast2(self):
#test switch(cst, vector, matrix)
#This case is not optimized for now.
x = theano.tensor.vector('x', dtype='int32')
y = theano.tensor.matrix('y', dtype='int64')
z = theano.tensor.switch(1, x , y)
f = theano.function([x,y],z)
assert len([x for x in f.maker.env.toposort() if
isinstance(x,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype='int64')
z = theano.tensor.switch(1, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) and
not isinstance(node.op.scalar_op,theano.scalar.basic.Cast)]) == 0
vx = numpy.array([ 4, 5, 6], dtype='int32')
vy = numpy.array([[7,8,9],[10,11,12]], dtype='int64')
assert numpy.all(f(vx,vy) == vx)
z = theano.tensor.switch(0, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([ 4, 5, 6], dtype='int32')
vy = numpy.array([[7,8,9],[10,11,12]], dtype='int64')
assert numpy.all(f(vx,vy) == vy)
class T_local_sum(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论