提交 eb1d2f9f authored 作者: Hengjean's avatar Hengjean

Added constant input support for GetItem

上级 6b60c53a
...@@ -39,7 +39,13 @@ class GetItem(Op): ...@@ -39,7 +39,13 @@ class GetItem(Op):
def make_node(self, x, index): def make_node(self, x, index):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert isinstance(index, Variable) if not isinstance(index, Variable):
if isinstance(index, slice):
index = Constant(SliceType(), index)
return Apply(self, [x, index], [x.type()])
else:
index = T.constant(index, ndim=0)
return Apply(self, [x, index], [x.ttype()])
if isinstance(index.type, SliceType): if isinstance(index.type, SliceType):
return Apply(self, [x, index], [x.type()]) return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0: elif isinstance(index, T.TensorVariable) and index.ndim == 0:
......
...@@ -70,6 +70,13 @@ class test_get_item(unittest.TestCase): ...@@ -70,6 +70,13 @@ class test_get_item(unittest.TestCase):
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x)) self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x))
z = mySymbolicMatricesList[0: 1: 1]
f = theano.function([mySymbolicMatricesList],
z)
self.assertTrue(numpy.array_equal(f([x]), [x]))
def test_wrong_input(self): def test_wrong_input(self):
mySymbolicMatricesList = TypedListType(T.TensorType( mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
...@@ -78,17 +85,38 @@ class test_get_item(unittest.TestCase): ...@@ -78,17 +85,38 @@ class test_get_item(unittest.TestCase):
self.assertRaises(TypeError, GetItem(), mySymbolicMatricesList, self.assertRaises(TypeError, GetItem(), mySymbolicMatricesList,
mySymbolicMatrix) mySymbolicMatrix)
def test_constant_input(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = GetItem()(mySymbolicMatricesList, 0)
f = theano.function([mySymbolicMatricesList],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x]), x))
z = GetItem()(mySymbolicMatricesList, slice(0, 1, 1))
f = theano.function([mySymbolicMatricesList],
z)
self.assertTrue(numpy.array_equal(f([x]), [x]))
class test_append(unittest.TestCase): class test_append(unittest.TestCase):
def test_sanity_check(self): def test_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType( mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
myMatrix = T.matrix() myMatrix = T.matrix()
z = Append()(mySymbolicMatricesList, myMatrix) z = Append()(mySymbolicMatricesList, myMatrix)
f = theano.function([mySymbolicMatricesList, myMatrix], z) f = theano.function([mySymbolicMatricesList, myMatrix], z,
accept_inplace=True)
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
...@@ -99,7 +127,7 @@ class test_append(unittest.TestCase): ...@@ -99,7 +127,7 @@ class test_append(unittest.TestCase):
class test_extend(unittest.TestCase): class test_extend(unittest.TestCase):
def test_sanity_check(self): def test_inplace(self):
mySymbolicMatricesList1 = TypedListType(T.TensorType( mySymbolicMatricesList1 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))() theano.config.floatX, (False, False)))()
mySymbolicMatricesList2 = TypedListType(T.TensorType( mySymbolicMatricesList2 = TypedListType(T.TensorType(
...@@ -108,7 +136,7 @@ class test_extend(unittest.TestCase): ...@@ -108,7 +136,7 @@ class test_extend(unittest.TestCase):
z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2], f = theano.function([mySymbolicMatricesList1, mySymbolicMatricesList2],
z) z, accept_inplace=True)
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论