提交 8b7d4eb0 authored 作者: James Bergstra's avatar James Bergstra

use eq instead of dtype_eq in TensorType.filter

上级 a042720c
......@@ -9,13 +9,6 @@ import numpy
import theano
def dtype_eq(a, b):
"""Returns true iff numpy.dtype objects `a` and `b` match for the purposes
of theano code generators.
"""
return a == b and a.num == b.num
def _asarray(a, dtype, order=None):
"""Convert the input to a Numpy array.
......
......@@ -543,10 +543,8 @@ class TensorType(Type):
'maybe you are trying to call a function on a (possibly '
'shared) variable instead of a numeric array?')
dtype_eq = theano.misc.safe_asarray.dtype_eq
if ((type(data) is numpy.ndarray)
and dtype_eq(data.dtype, self.numpy_dtype)):
and (data.dtype == self.numpy_dtype)):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
......@@ -556,7 +554,7 @@ class TensorType(Type):
if not (type(data) is numpy.ndarray):
raise TypeError("%s expected a ndarray object." % self,
data, type(data))
if not dtype_eq(data.dtype, self.numpy_dtype):
if data.dtype != self.numpy_dtype:
raise TypeError(("%s expected a ndarray object with "
"dtype = %s (got %s).") % (
self, self.numpy_dtype, data.dtype))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论