提交 d894f567 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6140 from lamblin/fix_gpuarray_filter

Let filter_variable fix broadcast between GPU vars
...@@ -91,6 +91,22 @@ def test_filter_float(): ...@@ -91,6 +91,22 @@ def test_filter_float():
del theano.compile.sharedvalue.shared.constructors[-1] del theano.compile.sharedvalue.shared.constructors[-1]
def test_filter_variable():
# Test that filter_variable accepts more restrictive broadcast
gpu_row = GpuArrayType(dtype=theano.config.floatX,
broadcastable=(True, False))
gpu_matrix = GpuArrayType(dtype=theano.config.floatX,
broadcastable=(False, False))
r = gpu_row()
m = gpu_matrix.filter_variable(r)
assert m.type == gpu_matrix
# On CPU as well
r = theano.tensor.row()
m = gpu_matrix.filter_variable(r)
assert m.type == gpu_matrix
def test_gpuarray_shared_scalar(): def test_gpuarray_shared_scalar():
# By default, we don't put scalar as shared variable on the GPU # By default, we don't put scalar as shared variable on the GPU
nose.tools.assert_raises( nose.tools.assert_raises(
......
...@@ -303,8 +303,6 @@ class GpuArrayType(Type): ...@@ -303,8 +303,6 @@ class GpuArrayType(Type):
return data return data
def filter_variable(self, other, allow_convert=True): def filter_variable(self, other, allow_convert=True):
from theano.gpuarray.basic_ops import GpuFromHost
if hasattr(other, '_as_GpuArrayVariable'): if hasattr(other, '_as_GpuArrayVariable'):
other = other._as_GpuArrayVariable(self.context_name) other = other._as_GpuArrayVariable(self.context_name)
...@@ -314,7 +312,7 @@ class GpuArrayType(Type): ...@@ -314,7 +312,7 @@ class GpuArrayType(Type):
if other.type == self: if other.type == self:
return other return other
if not isinstance(other.type, tensor.TensorType): if not isinstance(other.type, (TensorType, GpuArrayType)):
raise TypeError('Incompatible type', (self, other.type)) raise TypeError('Incompatible type', (self, other.type))
if (other.type.dtype != self.dtype): if (other.type.dtype != self.dtype):
raise TypeError('Incompatible dtype', (self.dtype, raise TypeError('Incompatible dtype', (self.dtype,
...@@ -335,7 +333,7 @@ class GpuArrayType(Type): ...@@ -335,7 +333,7 @@ class GpuArrayType(Type):
str(self.broadcastable))) str(self.broadcastable)))
other = other2 other = other2
return GpuFromHost(self.context_name)(other) return other.transfer(self.context_name)
@staticmethod @staticmethod
def values_eq(a, b, force_same_dtype=True): def values_eq(a, b, force_same_dtype=True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论