Commit 10abe09a authored by James Bergstra

fix to bug in gpu shared variable constructor when broadcastable flags specified by non-tuple

Parent b52c1470
import numpy
from theano.sandbox.cuda.var import float32_shared_constructor as f32sc
from theano.sandbox.cuda import CudaNdarrayType
def test_float32_shared_constructor():
    """Exercise float32_shared_constructor (f32sc) type inference.

    Covers the default (non-broadcastable) type, explicit ``broadcastable``
    given as a tuple, a list, and a numpy array (the bug being fixed:
    non-tuple broadcastable flags), and a non-matrix (4-D) value.
    """
    row = numpy.zeros((1, 10), dtype='float32')

    def same(x, y):
        return x == y

    # A CudaNdarray shared variable can be built from a plain ndarray;
    # by default no dimension is marked broadcastable.
    assert f32sc(row).type == CudaNdarrayType((False, False))

    # The broadcastable argument is accepted and need not strictly be a
    # tuple — list and ndarray forms must produce the same type.
    for flags in ((True, False), [True, False], numpy.array([True, False])):
        assert same(f32sc(row, broadcastable=flags).type,
                    CudaNdarrayType((True, False)))

    # Non-matrix shared variables (here 4-D) are supported as well.
    assert same(f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
                CudaNdarrayType((False,) * 4))
@@ -195,7 +195,9 @@ def float32_shared_constructor(value, name=None, strict=False,
         else:
             deviceval = value.copy()
     else:
-        deviceval = type_support_filter(value, broadcastable, False, None)
+        # type.broadcastable is guaranteed to be a tuple, which this next
+        # function requires
+        deviceval = type_support_filter(value, type.broadcastable, False, None)
     try:
         rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict)
...
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论