提交 5103e4f4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3902 from abergeron/fix_filter

Fix allow_downcast on gpuarray
...@@ -195,7 +195,7 @@ def rebuild_collect_shared(outputs, ...@@ -195,7 +195,7 @@ def rebuild_collect_shared(outputs,
store_into, store_into,
store_into.type, store_into.type,
update_val, update_val,
update_val.type)) getattr(update_val, 'type', None)))
err_sug = ('If the difference is related to the broadcast pattern,' err_sug = ('If the difference is related to the broadcast pattern,'
' you can call the' ' you can call the'
' tensor.unbroadcast(var, axis_to_unbroadcast[, ...])' ' tensor.unbroadcast(var, axis_to_unbroadcast[, ...])'
......
...@@ -3,6 +3,7 @@ import numpy ...@@ -3,6 +3,7 @@ import numpy
import theano import theano
from theano.compile import DeepCopyOp from theano.compile import DeepCopyOp
from .config import test_ctx_name
from .test_basic_ops import rand_gpuarray from .test_basic_ops import rand_gpuarray
from ..type import GpuArrayType from ..type import GpuArrayType
...@@ -37,3 +38,8 @@ def test_specify_shape(): ...@@ -37,3 +38,8 @@ def test_specify_shape():
g = GpuArrayType(dtype='float32', broadcastable=(False,))('g') g = GpuArrayType(dtype='float32', broadcastable=(False,))('g')
f = theano.function([g], theano.tensor.specify_shape(g, [20])) f = theano.function([g], theano.tensor.specify_shape(g, [20]))
f(a) f(a)
def test_filter_float():
    """Regression test: compiling an update that assigns a plain Python
    float (0.0) to a float32 shared variable on the GPU target must not
    raise — i.e. `allow_downcast`/filter handles a bare float whose dtype
    matches floatX. Only compilation is exercised; the function is not run.
    """
    # Shared 0-d float32 variable placed on the GPU test context.
    shared_var = theano.shared(numpy.array(0.0, dtype='float32'),
                               target=test_ctx_name)
    # The update value is a raw Python float, not a symbolic variable;
    # building the function triggers the type-filtering path under test.
    theano.function([], updates=[(shared_var, 0.0)])
...@@ -189,13 +189,13 @@ class GpuArrayType(Type): ...@@ -189,13 +189,13 @@ class GpuArrayType(Type):
if self.context != data.context: if self.context != data.context:
raise TypeError("data context does not match type context") raise TypeError("data context does not match type context")
# fallthrough to ndim check # fallthrough to ndim check
elif (allow_downcast or elif (allow_downcast or
(allow_downcast is None and (allow_downcast is None and
type(data) == float and type(data) == float and
self.dtype == config.floatX)): self.dtype == config.floatX)):
data = gpuarray.array(data, dtype=self.typecode, copy=False, data = gpuarray.array(data, dtype=self.typecode, copy=False,
ndmin=len(self.broadcastable), ndmin=len(self.broadcastable),
context=self.context) context=self.context)
else: else:
if not hasattr(data, 'dtype'): if not hasattr(data, 'dtype'):
# This is to convert objects that don't have a dtype # This is to convert objects that don't have a dtype
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论