提交 d894f567 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6140 from lamblin/fix_gpuarray_filter

Let filter_variable fix broadcast between GPU vars
......@@ -91,6 +91,22 @@ def test_filter_float():
del theano.compile.sharedvalue.shared.constructors[-1]
def test_filter_variable():
    """Check that ``filter_variable`` accepts a variable with a more
    restrictive broadcastable pattern and converts it to the target type."""
    row_type = GpuArrayType(dtype=theano.config.floatX,
                            broadcastable=(True, False))
    matrix_type = GpuArrayType(dtype=theano.config.floatX,
                               broadcastable=(False, False))

    # A GPU row variable should be accepted by the GPU matrix type.
    gpu_var = row_type()
    converted = matrix_type.filter_variable(gpu_var)
    assert converted.type == matrix_type

    # The same conversion should work starting from a CPU row variable.
    cpu_var = theano.tensor.row()
    converted = matrix_type.filter_variable(cpu_var)
    assert converted.type == matrix_type
def test_gpuarray_shared_scalar():
# By default, we don't put scalar as shared variable on the GPU
nose.tools.assert_raises(
......
......@@ -303,8 +303,6 @@ class GpuArrayType(Type):
return data
def filter_variable(self, other, allow_convert=True):
from theano.gpuarray.basic_ops import GpuFromHost
if hasattr(other, '_as_GpuArrayVariable'):
other = other._as_GpuArrayVariable(self.context_name)
......@@ -314,7 +312,7 @@ class GpuArrayType(Type):
if other.type == self:
return other
if not isinstance(other.type, tensor.TensorType):
if not isinstance(other.type, (TensorType, GpuArrayType)):
raise TypeError('Incompatible type', (self, other.type))
if (other.type.dtype != self.dtype):
raise TypeError('Incompatible dtype', (self.dtype,
......@@ -335,7 +333,7 @@ class GpuArrayType(Type):
str(self.broadcastable)))
other = other2
return GpuFromHost(self.context_name)(other)
return other.transfer(self.context_name)
@staticmethod
def values_eq(a, b, force_same_dtype=True):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论