提交 39350199 authored 作者: nouiz's avatar nouiz

Merge pull request #209 from jaberg/rebase2_deepcopy_type_filter_bug

Rebase2 deepcopy type filter bug
......@@ -8,6 +8,7 @@ import numpy
import theano
def _asarray(a, dtype, order=None):
"""Convert the input to a Numpy array.
......
......@@ -15,8 +15,6 @@ import numpy, theano
from theano import gof
from theano.gof import Apply, Constant, Op, Type, Value, Variable
import elemwise
from theano import scalar as scal
from theano.gof.python25 import partial, any, all
......@@ -544,21 +542,29 @@ class TensorType(Type):
'Expected an array-like object, but found a Variable: '
'maybe you are trying to call a function on a (possibly '
'shared) variable instead of a numeric array?')
if (type(data) is numpy.ndarray) and (data.dtype is self.numpy_dtype):
pass # fall through to ndim check
if ((type(data) is numpy.ndarray)
and (data.dtype == self.numpy_dtype)):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif strict:
# If any of the two conditions above was not met,
# we raise a meaningful TypeError.
if not (type(data) is numpy.ndarray):
raise TypeError("%s expected a ndarray object." % self, data, type(data))
if not (data.dtype is self.numpy_dtype):
raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.numpy_dtype, data.dtype))
assert False, "This point in the program should never be reached."
raise TypeError("%s expected a ndarray object." % self,
data, type(data))
if data.dtype != self.numpy_dtype:
raise TypeError(("%s expected a ndarray object with "
"dtype = %s (got %s).") % (
self, self.numpy_dtype, data.dtype))
assert False, "This point should never be reached."
else:
if allow_downcast:
# Convert to self.dtype, regardless of the type of data
data = theano._asarray(data, dtype=self.dtype) #TODO - consider to pad shape with ones
# to make it consistent with self.broadcastable... like vector->row type thing
data = theano._asarray(data, dtype=self.dtype)
# TODO: consider to pad shape with ones to make it consistent
# with self.broadcastable... like vector->row type thing
else:
if isinstance(data, numpy.ndarray):
# Check if self.dtype can accurately represent data
......
......@@ -71,3 +71,14 @@ def test_bug_2009_07_17_borrowed_output():
assert id_z != id_other
# Just to be 100% sure, ensure that z was not altered.
assert (z == z_backup).all()
def test_deepcopied_type_filter():
a = copy.deepcopy(tensor.matrix())
# The following should run cleanly.
# As of commit 731e2d2fa68487733320d341d08b454a50c90d12
# it was failing.
a.type.filter(
numpy.ones((2,2), dtype=a.dtype),
strict=True)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论