提交 200db67b authored 作者: Hengjean's avatar Hengjean

Added get_item op.

上级 ddf1a974
......@@ -52,6 +52,9 @@ class SliceType(Type):
def __str__(self):
return "slice"
def __eq__(self, other):
return type(self) is SliceType and type(other) is SliceType
slicetype = SliceType()
......
from type import TypedListType
from basic import *
from type import TypedListType
from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType
from theano import tensor as T
class _typed_list_py_operators:
def __getitem__(self, index):
return get_item()(self, index)
class TypedListVariable(_typed_list_py_operators, Variable):
"""
Subclass to add the typed list operators to the basic `Variable` class.
"""
TypedListType.Variable = TypedListVariable
class get_item(Op):
"""
get specified slice of a typed list
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, index):
if index.type == SliceType():
return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0:
return Apply(self, [x, index], [x.ttype()])
else:
raise TypeError('Expected scalar or slice as index.')
def perform(self, node, (x, index), (out, )):
if not isinstance(index, slice):
index = int(index)
out[0] = x[index]
def __str__(self):
return self.__class__.__name__
import unittest
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
import numpy
from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy.testing.noseclasses import KnownFailureTest
import theano
import theano.typed_list
from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import get_item
from theano.tests import unittest_tools as utt
#took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum)
+ minimum, dtype=theano.config.floatX)
class test_get_slice(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_sanity_check_slice(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicSlice = SliceType()()
z = get_item()(mySymbolicMatricesList, mySymbolicSlice)
self.assertFalse(isinstance(z, T.TensorVariable))
f = theano.function([mySymbolicMatricesList, mySymbolicSlice],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(f([x], slice(0, 1, 1)), [x]))
def test_sanity_check_single(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicScalar = T.scalar()
z = get_item()(mySymbolicMatricesList, mySymbolicScalar)
f = theano.function([mySymbolicMatricesList, mySymbolicScalar],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x))
def test_interface(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicScalar = T.scalar()
z = mySymbolicMatricesList[mySymbolicScalar]
f = theano.function([mySymbolicMatricesList, mySymbolicScalar],
z)
x = rand_ranged_matrix(-1000, 1000, [100, 100])
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(0)), x))
def test_wrong_input(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatrix = T.matrix()
self.assertRaises(TypeError, get_item(), mySymbolicMatricesList,
mySymbolicMatrix)
......@@ -187,3 +187,10 @@ class test_typed_list_type(unittest.TestCase):
self.assertFalse(myManualNestedType1 == myManualNestedType2)
self.assertFalse(myManualNestedType2 == myManualNestedType1)
def test_variable_is_Typed_List_variable(self):
mySymbolicVariable = TypedListType(T.TensorType(theano.config.floatX,
(False, False)))()
self.assertTrue(isinstance(mySymbolicVariable,
theano.typed_list.TypedListVariable))
......@@ -23,6 +23,8 @@ class TypedListType(gof.Type):
else:
self.ttype = TypedListType(ttype, depth - 1)
self.Variable.ttype = self.ttype
def filter(self, x, strict=False, allow_downcast=None):
"""
:Parameters:
......@@ -67,3 +69,13 @@ class TypedListType(gof.Type):
return self.ttype.get_depth() + 1
else:
return 0
def make_variable(self, name=None):
"""Return a `TypedListVariable` of this type
:Parameters:
- `name`: str
A pretty name to identify this `Variable` when printing and
debugging
"""
return self.Variable(self, name=name)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论