提交 ec537fd6 authored 作者: James Bergstra's avatar James Bergstra

adding unittest for CudaNdarrayType.filter with broadcasting

上级 1fef32f1
......@@ -75,6 +75,19 @@ class T_updates(unittest.TestCase):
updates=output_updates, givens=output_givens)
output_func()
def test_3(self):
# Test that broadcastable dimensions don't screw up
# update expressions.
data = numpy.random.rand(10,10).astype('float32')
output_var = f32sc(name="output",
value=numpy.zeros((10,10), 'float32'))
# the update_var has type matrix, and the update expression
# is a broadcasted scalar, and that should be allowed.
output_func = theano.function(inputs=[], outputs=[],
updates={output_var:output_var.sum().dimshuffle('x', 'x')})
output_func()
class T_ifelse(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
......@@ -128,9 +128,15 @@ class CudaNdarrayType(Type):
if (other.type.dtype != self.dtype):
raise TypeError('Incompatible dtype', (self.dtype,
other.type.dtype))
if (other.type.broadcastable != self.broadcastable):
if any(bi and not obi
for obi, bi in zip(
other.type.broadcastable,
self.broadcastable)):
raise TypeError('Incompatible broadcastable', (self.broadcastable,
other.type.broadcastable))
if other.type.broadcastable != self.broadcastable:
rebroadcast = tensor.Rebroadcast(*enumerate(self.broadcastable))
other = rebroadcast(other)
return theano.sandbox.cuda.basic_ops.GpuFromHost()(other)
@staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论