提交 e19bfe0d authored 作者: Hengjean's avatar Hengjean

Fixed bug

上级 64305ce1
from type import TypedListType
import theano
from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType
from theano import tensor as T
......@@ -270,10 +271,10 @@ class Index(Op):
if isinstance(elem, numpy.ndarray):
for y in range(x.__len__()):
if numpy.array_equal(x[y], elem):
out[0] = numpy.asarray([y])
out[0] = numpy.asarray([y], dtype=theano.config.floatX)
break
else:
out[0] = numpy.asarray([x.index(elem)])
out[0] = numpy.asarray([x.index(elem)], dtype=theano.config.floatX)
def __str__(self):
return self.__class__.__name__
......@@ -303,9 +304,9 @@ class Count(Op):
for y in range(x.__len__()):
if numpy.array_equal(x[y], elem):
out[0] += 1
out[0] = numpy.asarray([out[0]])
out[0] = numpy.asarray([out[0]], dtype=theano.config.floatX)
else:
out[0] = numpy.asarray([x.count(elem)])
out[0] = numpy.asarray([x.count(elem)], dtype=theano.config.floatX)
def __str__(self):
return self.__class__.__name__
......@@ -230,7 +230,8 @@ class test_insert(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1, dtype=theano.config.floatX
), y), [x, y]))
def test_sanity_check(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
......@@ -246,7 +247,8 @@ class test_insert(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1,
dtype=theano.config.floatX), y), [x, y]))
def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
......@@ -262,7 +264,8 @@ class test_insert(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1), y), [x, y]))
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1,
dtype=theano.config.floatX), y), [x, y]))
class test_remove(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论