提交 575f8edc authored 作者: James Bergstra's avatar James Bergstra

optimized TensorType.filter because all inputs to function are passed through it

上级 d66b0368
...@@ -330,6 +330,7 @@ class TensorType(Type): ...@@ -330,6 +330,7 @@ class TensorType(Type):
self.broadcastable = tuple(broadcastable) self.broadcastable = tuple(broadcastable)
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype)
if shape is None: if shape is None:
#backport self.shape = tuple((1 if b else None) for b in self.broadcastable) #backport self.shape = tuple((1 if b else None) for b in self.broadcastable)
l=[] l=[]
...@@ -360,16 +361,16 @@ class TensorType(Type): ...@@ -360,16 +361,16 @@ class TensorType(Type):
This function is not meant to be called in user code. It is for This function is not meant to be called in user code. It is for
`Linker` instances to use when running a compiled graph. `Linker` instances to use when running a compiled graph.
""" """
_data = data if (type(data) is numpy.ndarray) and (data.dtype is self.numpy_dtype):
if strict: pass # fall through to ndim check
elif strict:
# this is its own subcase that doesn't fall through to anything
if not isinstance(data, numpy.ndarray): if not isinstance(data, numpy.ndarray):
raise TypeError("%s expected a ndarray object.", data, type(data)) raise TypeError("%s expected a ndarray object.", data, type(data))
if not str(data.dtype) == self.dtype: if not str(data.dtype) == self.dtype:
raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype)) raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype))
if not data.ndim == self.ndim: if not data.ndim == self.ndim:
raise TypeError("%s expected a ndarray object with %s dimensions (got %s)." % (self, self.ndim, data.ndim)) raise TypeError("%s expected a ndarray object with %s dimensions (got %s)." % (self, self.ndim, data.ndim))
if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))):
raise TypeError("non-finite elements not allowed")
if TensorType.use_shape: if TensorType.use_shape:
for si, di in zip(self.shape, data.shape): for si, di in zip(self.shape, data.shape):
...@@ -378,11 +379,17 @@ class TensorType(Type): ...@@ -378,11 +379,17 @@ class TensorType(Type):
self, self.shape, data.shape)) self, self.shape, data.shape))
return data return data
else: else:
data = theano._asarray(data, dtype = self.dtype) data = theano._asarray(data, dtype = self.dtype) #TODO - consider to pad shape with ones
if not self.ndim == data.ndim: # to make it consistent with self.broadcastable... like vector->row type thing
if self.ndim != data.ndim:
raise TypeError("Wrong number of dimensions: expected %s, got %s with shape %s." % (self.ndim, data.ndim, data.shape), data) raise TypeError("Wrong number of dimensions: expected %s, got %s with shape %s." % (self.ndim, data.ndim, data.shape), data)
if any(b and d != 1 for d, b in zip(data.shape, self.broadcastable)): i = 0
for b in self.broadcastable:
if b and data.shape[i] != 1:
raise TypeError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable) raise TypeError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable)
i+=1
if self.filter_checks_isfinite and (not numpy.all(numpy.isfinite(data))):
raise TypeError("non-finite elements not allowed")
return data return data
def dtype_specs(self): def dtype_specs(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论