提交 64305ce1 authored 作者: Hengjean's avatar Hengjean

Fixed bug

上级 7eef6f8d
...@@ -56,7 +56,8 @@ class test_get_item(unittest.TestCase): ...@@ -56,7 +56,8 @@ class test_get_item(unittest.TestCase):
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x)) self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0,
dtype=theano.config.floatX)), x))
def test_interface(self): def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType( mySymbolicMatricesList = TypedListType(T.TensorType(
...@@ -70,7 +71,8 @@ class test_get_item(unittest.TestCase): ...@@ -70,7 +71,8 @@ class test_get_item(unittest.TestCase):
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x)) self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0,
dtype=theano.config.floatX)), x))
z = mySymbolicMatricesList[0: 1: 1] z = mySymbolicMatricesList[0: 1: 1]
......
from theano import gof from theano import gof
from theano import numpy
class TypedListType(gof.Type): class TypedListType(gof.Type):
...@@ -65,3 +65,9 @@ class TypedListType(gof.Type): ...@@ -65,3 +65,9 @@ class TypedListType(gof.Type):
return self.ttype.get_depth() + 1 return self.ttype.get_depth() + 1
else: else:
return 0 return 0
def values_eq(self, a, b):
if isinstance(a, numpy.ndarray):
return numpy.array_equal(a, b)
else:
return a == b
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论