提交 a62a13c1 authored 作者: Samira Shabanian's avatar Samira Shabanian

pep8 typed list directory

上级 fba9e145
...@@ -51,13 +51,7 @@ whitelist_flake8 = [ ...@@ -51,13 +51,7 @@ whitelist_flake8 = [
"compile/tests/test_pfunc.py", "compile/tests/test_pfunc.py",
"compile/tests/test_debugmode.py", "compile/tests/test_debugmode.py",
"compile/tests/test_profiling.py", "compile/tests/test_profiling.py",
"typed_list/type.py",
"typed_list/__init__.py", "typed_list/__init__.py",
"typed_list/opt.py",
"typed_list/basic.py",
"typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/tests/test_subtensor.py", "tensor/tests/test_subtensor.py",
"tensor/tests/test_utils.py", "tensor/tests/test_utils.py",
......
import copy
import numpy import numpy
...@@ -596,7 +595,7 @@ class MakeList(Op): ...@@ -596,7 +595,7 @@ class MakeList(Op):
a2 = [] a2 = []
for elem in a: for elem in a:
if not isinstance(elem, theano.gof.Variable): if not isinstance(elem, theano.gof.Variable):
elem = as_tensor_variable(elem) elem = theano.tensor.as_tensor_variable(elem)
a2.append(elem) a2.append(elem)
if not all(a2[0].type == elem.type for elem in a2): if not all(a2[0].type == elem.type for elem in a2):
raise TypeError( raise TypeError(
......
from theano import gof from theano import gof
from theano import compile from theano import compile
from theano.gof import TopoOptimizer from theano.gof import TopoOptimizer
from theano.typed_list.basic import (Reverse, from theano.typed_list.basic import Reverse, Append, Extend, Insert, Remove
Append, Extend, Insert, Remove)
@gof.local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True) @gof.local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(node): def typed_list_inplace_opt(node):
if isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) \ if (isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) and not
and not node.op.inplace: node.op.inplace):
new_op = node.op.__class__( new_op = node.op.__class__(inplace=True)
inplace=True)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
compile.optdb.register('typed_list_inplace_opt', compile.optdb.register('typed_list_inplace_opt',
TopoOptimizer(typed_list_inplace_opt, TopoOptimizer(typed_list_inplace_opt,
failure_callback=TopoOptimizer.warn_inplace), 60, failure_callback=TopoOptimizer.warn_inplace),
'fast_run', 'inplace') 60, 'fast_run', 'inplace')
...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType ...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse,
Index, Count, Length, make_list, MakeList) Index, Count, Length, make_list)
from theano import sparse from theano import sparse
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
# TODO, handle the case where scipy isn't installed. # TODO, handle the case where scipy isn't installed.
...@@ -23,8 +23,8 @@ except ImportError: ...@@ -23,8 +23,8 @@ except ImportError:
# took from tensors/tests/test_basic.py # took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape): def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) +
+ minimum, dtype=theano.config.floatX) minimum, dtype=theano.config.floatX)
# took from sparse/tests/test_basic.py # took from sparse/tests/test_basic.py
...@@ -82,7 +82,6 @@ class test_get_item(unittest.TestCase): ...@@ -82,7 +82,6 @@ class test_get_item(unittest.TestCase):
z) z)
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x], self.assertTrue(numpy.array_equal(f([x],
numpy.asarray(0, dtype='int64')), numpy.asarray(0, dtype='int64')),
...@@ -555,7 +554,7 @@ class test_length(unittest.TestCase): ...@@ -555,7 +554,7 @@ class test_length(unittest.TestCase):
self.assertTrue(f([x, x]) == 2) self.assertTrue(f([x, x]) == 2)
class T_MakeList(unittest.TestCase): class TestMakeList(unittest.TestCase):
def test_wrong_shape(self): def test_wrong_shape(self):
a = T.vector() a = T.vector()
...@@ -563,22 +562,22 @@ class T_MakeList(unittest.TestCase): ...@@ -563,22 +562,22 @@ class T_MakeList(unittest.TestCase):
self.assertRaises(TypeError, make_list, (a, b)) self.assertRaises(TypeError, make_list, (a, b))
def correct_answer(self): def test_correct_answer(self):
a = T.matrix() a = T.matrix()
b = T.matrix() b = T.matrix()
x = T.tensor3() x = T.tensor3()
y = T.tensor3() y = T.tensor3()
A = numpy.random.rand(5) A = numpy.random.rand(5, 3)
B = numpy.random.rand(7) B = numpy.random.rand(7, 2)
X = numpy.random.rand(5, 6) X = numpy.random.rand(5, 6, 1)
Y = numpy.random.rand(1, 9) Y = numpy.random.rand(1, 9, 3)
make_list((3., 4.))
c = make_list((a, b)) c = make_list((a, b))
z = make_list((x, y)) z = make_list((x, y))
fc = function([a, b], c) fc = theano.function([a, b], c)
fz = function([x, y], z) fz = theano.function([x, y], z)
self.assertTrue(fc(A, B) == [A, B])
self.assertTrue(f([A, B]) == [A, B]) self.assertTrue(fz(X, Y) == [X, Y])
self.assertTrue(f([X, Y]) == [X, Y])
...@@ -5,18 +5,16 @@ import numpy ...@@ -5,18 +5,16 @@ import numpy
import theano import theano
import theano.typed_list import theano.typed_list
from theano import tensor as T from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse)
Index, Count)
from theano import In from theano import In
# took from tensors/tests/test_basic.py # took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape): def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) +
+ minimum, dtype=theano.config.floatX) minimum, dtype=theano.config.floatX)
class test_inplace(unittest.TestCase): class test_inplace(unittest.TestCase):
...@@ -44,7 +42,8 @@ class test_inplace(unittest.TestCase): ...@@ -44,7 +42,8 @@ class test_inplace(unittest.TestCase):
z = Append()(mySymbolicMatricesList, mySymbolicMatrix) z = Append()(mySymbolicMatricesList, mySymbolicMatrix)
m = theano.compile.mode.get_default_mode().including("typed_list_inplace_opt") m = theano.compile.mode.get_default_mode().including("typed_list_inplace_opt")
f = theano.function([In(mySymbolicMatricesList, borrow=True, f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True), In(mySymbolicMatrix, borrow=True, mutable=True),
In(mySymbolicMatrix, borrow=True,
mutable=True)], z, accept_inplace=True, mode=m) mutable=True)], z, accept_inplace=True, mode=m)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace) self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
......
...@@ -11,8 +11,8 @@ from theano.tests import unittest_tools as utt ...@@ -11,8 +11,8 @@ from theano.tests import unittest_tools as utt
# took from tensors/tests/test_basic.py # took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape): def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum) +
+ minimum, dtype=theano.config.floatX) minimum, dtype=theano.config.floatX)
class test_typed_list_type(unittest.TestCase): class test_typed_list_type(unittest.TestCase):
...@@ -177,8 +177,7 @@ class test_typed_list_type(unittest.TestCase): ...@@ -177,8 +177,7 @@ class test_typed_list_type(unittest.TestCase):
myManualNestedType1 = TypedListType(TypedListType( myManualNestedType1 = TypedListType(TypedListType(
TypedListType(myType))) TypedListType(myType)))
myManualNestedType2 = TypedListType(TypedListType( myManualNestedType2 = TypedListType(TypedListType(myType))
myType))
self.assertFalse(myManualNestedType1 == myManualNestedType2) self.assertFalse(myManualNestedType1 == myManualNestedType2)
self.assertFalse(myManualNestedType2 == myManualNestedType1) self.assertFalse(myManualNestedType2 == myManualNestedType1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论