Commit 034bb5a3 authored by Pascal Lamblin

PEP 8 fixes on files I've recently worked on.

Parent 41a4a100
import numpy
import unittest
import copy
import theano
from theano.tensor import Tensor, TensorType
from theano.compile.sharedvalue import *


class Test_SharedVariable(unittest.TestCase):

    def test_ctors(self):

-        if 0: #when using an implementation that handles scalars with Scalar type
+        if 0:
+            # when using an implementation that handles scalars with
+            # Scalar type
            assert shared(7).type == Scalar('int64')
            assert shared(7.0).type == Scalar('float64')
            assert shared(7, dtype='float64').type == Scalar('float64')

@@ -24,14 +26,16 @@ class Test_SharedVariable(unittest.TestCase):
        assert shared(numpy.float32(7)).type == theano.tensor.fscalar

        # test tensor constructor
-        b = shared(numpy.zeros((5,5), dtype='int32'))
+        b = shared(numpy.zeros((5, 5), dtype='int32'))
-        assert b.type == TensorType('int32', broadcastable=[False,False])
+        assert b.type == TensorType('int32', broadcastable=[False, False])
-        b = shared(numpy.random.rand(4,5))
+        b = shared(numpy.random.rand(4, 5))
-        assert b.type == TensorType('float64', broadcastable=[False,False])
+        assert b.type == TensorType('float64', broadcastable=[False, False])
-        b = shared(numpy.random.rand(5,1,2))
+        b = shared(numpy.random.rand(5, 1, 2))
-        assert b.type == TensorType('float64', broadcastable=[False,False,False])
+        assert b.type == TensorType('float64',
+                                    broadcastable=[False, False, False])

        assert shared([]).type == generic

        def badfunc():
            shared(7, bad_kw=False)
        self.assertRaises(TypeError, badfunc)

@@ -70,7 +74,7 @@ class Test_SharedVariable(unittest.TestCase):
        SharedVariable(
            name='u',
            type=Tensor(broadcastable=[False], dtype='float64'),
-            value=[1, 2], #different dtype and not a numpy array
+            value=[1, 2],  # different dtype and not a numpy array
            strict=False)

        # here the value is not castable, and we're not strict about it,

@@ -79,7 +83,7 @@ class Test_SharedVariable(unittest.TestCase):
            SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
-                value=dict(), #not an array by any stretch
+                value=dict(),  # not an array by any stretch
                strict=False)
            assert 0
        except TypeError:

@@ -96,10 +100,10 @@ class Test_SharedVariable(unittest.TestCase):
            strict=False)

        # check that assignments to value are cast properly
-        u.set_value([3,4])
+        u.set_value([3, 4])
        assert type(u.get_value()) is numpy.ndarray
        assert str(u.get_value(borrow=True).dtype) == 'float64'
-        assert numpy.all(u.get_value() == [3,4])
+        assert numpy.all(u.get_value() == [3, 4])

        # check that assignments of nonsense fail
        try:

@@ -109,7 +113,7 @@ class Test_SharedVariable(unittest.TestCase):
            pass

        # check that an assignment of a perfect value results in no copying
-        uval = theano._asarray([5,6,7,8], dtype='float64')
+        uval = theano._asarray([5, 6, 7, 8], dtype='float64')
        u.set_value(uval, borrow=True)
        assert u.get_value(borrow=True) is uval

@@ -149,10 +153,8 @@ class Test_SharedVariable(unittest.TestCase):
        assert b.type == theano.tensor.dscalar
        self.assertRaises(TypeError, f, b, 8)

-        c = shared(numpy.zeros((5,5), dtype='float32'))
-        self.assertRaises(TypeError, f, b, numpy.random.rand(5,5))
+        b = shared(numpy.zeros((5, 5), dtype='float32'))
+        self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))

    def test_tensor_strict(self):
        def f(var, val):

@@ -192,19 +194,16 @@ class Test_SharedVariable(unittest.TestCase):
        # assert b.type == theano.tensor.dvector
        # self.assertRaises(TypeError, f, b, 8)

-        c = shared(numpy.zeros((5,5), dtype='float32'))
-        self.assertRaises(TypeError, f, b, numpy.random.rand(5,5))
+        b = shared(numpy.zeros((5, 5), dtype='float32'))
+        self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))

    def test_scalar_floatX(self):

-        #
-        # the test should assure that floatX is not used in the shared constructor for scalars
-        # Shared values can change, and since we don't know the range they might take, we
-        # should keep the same bit width / precision as the original value used to create the
-        # shared variable.
-        #
+        # the test should assure that floatX is not used in the shared
+        # constructor for scalars. Shared values can change, and since we
+        # don't know the range they might take, we should keep the same
+        # bit width / precision as the original value used to create the
+        # shared variable.

        # Since downcasting of a value now raises an Exception,

@@ -213,48 +212,46 @@ class Test_SharedVariable(unittest.TestCase):
        b = shared(numpy.int64(7), allow_downcast=True)
        assert b.type == theano.tensor.lscalar
-        f(b,8.23)
+        f(b, 8.23)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.int32(7), allow_downcast=True)
        assert b.type == theano.tensor.iscalar
-        f(b,8.23)
+        f(b, 8.23)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.int16(7), allow_downcast=True)
        assert b.type == theano.tensor.wscalar
-        f(b,8.23)
+        f(b, 8.23)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.int8(7), allow_downcast=True)
        assert b.type == theano.tensor.bscalar
-        f(b,8.23)
+        f(b, 8.23)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.float64(7.234), allow_downcast=True)
        assert b.type == theano.tensor.dscalar
-        f(b,8)
+        f(b, 8)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.float32(7.234), allow_downcast=True)
        assert b.type == theano.tensor.fscalar
-        f(b,8)
+        f(b, 8)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(numpy.float(7.234), allow_downcast=True)
        assert b.type == theano.tensor.dscalar
-        f(b,8)
+        f(b, 8)
-        assert b.get_value()==8
+        assert b.get_value() == 8

        b = shared(7.234, allow_downcast=True)
        assert b.type == theano.tensor.dscalar
-        f(b,8)
+        f(b, 8)
-        assert b.get_value()==8
+        assert b.get_value() == 8

-        c = shared(numpy.zeros((5,5), dtype='float32'), allow_downcast=True)
-        self.assertRaises(TypeError, f, b, numpy.random.rand(5,5))
+        b = shared(numpy.zeros((5, 5), dtype='float32'), allow_downcast=True)
+        self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))

    def test_tensor_floatX(self):
        def f(var, val):

@@ -262,32 +259,32 @@ class Test_SharedVariable(unittest.TestCase):
        b = shared(numpy.int64([7]), allow_downcast=True)
        assert b.type == theano.tensor.lvector
-        f(b,[8.23])
+        f(b, [8.23])
        assert b.get_value() == 8

        b = shared(numpy.int32([7]), allow_downcast=True)
        assert b.type == theano.tensor.ivector
-        f(b,[8.23])
+        f(b, [8.23])
        assert b.get_value() == 8

        b = shared(numpy.int16([7]), allow_downcast=True)
        assert b.type == theano.tensor.wvector
-        f(b,[8.23])
+        f(b, [8.23])
        assert b.get_value() == 8

        b = shared(numpy.int8([7]), allow_downcast=True)
        assert b.type == theano.tensor.bvector
-        f(b,[8.23])
+        f(b, [8.23])
        assert b.get_value() == 8

        b = shared(numpy.float64([7.234]), allow_downcast=True)
        assert b.type == theano.tensor.dvector
-        f(b,[8])
+        f(b, [8])
        assert b.get_value() == 8

        b = shared(numpy.float32([7.234]), allow_downcast=True)
        assert b.type == theano.tensor.fvector
-        f(b,[8])
+        f(b, [8])
        assert b.get_value() == 8

        #numpy.float([7.234]) don't work

@@ -300,10 +297,12 @@ class Test_SharedVariable(unittest.TestCase):
        # assert b.type == theano.tensor.dvector
        # f(b,[8])

-        b = shared(numpy.asarray([7.234],dtype=theano.config.floatX), allow_downcast=True)
+        b = shared(numpy.asarray([7.234], dtype=theano.config.floatX),
+                   allow_downcast=True)
        assert b.dtype == theano.config.floatX
-        f(b,[8])
+        f(b, [8])
        assert b.get_value() == 8

-        c = shared(numpy.zeros((5,5), dtype='float32'), allow_downcast=True)
-        self.assertRaises(TypeError, f, b, numpy.random.rand(5,5))
+        b = shared(numpy.zeros((5, 5), dtype='float32'),
+                   allow_downcast=True)
+        self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
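
Aside: the assertions above pin down the shared-variable contract that this
commit leaves untouched: the constructor takes its dtype from the value it is
given (floatX is not applied to scalars), and downcasting on assignment only
happens under allow_downcast=True. A minimal sketch of that behavior, assuming
a Theano of this vintage (get_value/set_value API):

    import numpy
    import theano
    from theano import shared

    s = shared(numpy.int16(7))
    assert s.type == theano.tensor.wscalar  # dtype comes from the value, not floatX

    try:
        s.set_value(8.23)  # float -> int16 is a downcast, so this raises
    except TypeError:
        pass

    s2 = shared(numpy.int16(7), allow_downcast=True)
    s2.set_value(8.23)  # explicitly allowed: value is cast (truncated) to int16
    assert s2.get_value() == 8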
Diff collapsed.
""" """
Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5) Helper functions to make gof backwards compatible
(tested on python 2.4 and 2.5)
""" """
import collections import collections
import sys import sys
if sys.version_info[:2] < (2,5): if sys.version_info[:2] < (2, 5):
def all(iterable): def all(iterable):
for element in iterable: for element in iterable:
...@@ -55,16 +57,19 @@ if sys.version_info[:2] < (2,5): ...@@ -55,16 +57,19 @@ if sys.version_info[:2] < (2,5):
raise TypeError('first argument must be callable') raise TypeError('first argument must be callable')
dict.__init__(self, *a, **kw) dict.__init__(self, *a, **kw)
self.default_factory = default_factory self.default_factory = default_factory
def __getitem__(self, key): def __getitem__(self, key):
try: try:
return dict.__getitem__(self, key) return dict.__getitem__(self, key)
except KeyError: except KeyError:
return self.__missing__(key) return self.__missing__(key)
def __missing__(self, key): def __missing__(self, key):
if self.default_factory is None: if self.default_factory is None:
raise KeyError(key) raise KeyError(key)
self[key] = value = self.default_factory() self[key] = value = self.default_factory()
return value return value
def __reduce__(self): def __reduce__(self):
if self.default_factory is None: if self.default_factory is None:
args = tuple() args = tuple()
...@@ -72,14 +77,18 @@ if sys.version_info[:2] < (2,5): ...@@ -72,14 +77,18 @@ if sys.version_info[:2] < (2,5):
args = self.default_factory, args = self.default_factory,
# consider replacing items() with iteritems() # consider replacing items() with iteritems()
return type(self), args, None, None, self.items() return type(self), args, None, None, self.items()
def copy(self): def copy(self):
return self.__copy__() return self.__copy__()
def __copy__(self): def __copy__(self):
return type(self)(self.default_factory, self) return type(self)(self.default_factory, self)
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
import copy import copy
return type(self)(self.default_factory, return type(self)(self.default_factory,
copy.deepcopy(self.items())) copy.deepcopy(self.items()))
def __repr__(self): def __repr__(self):
return 'defaultdict(%s, %s)' % (self.default_factory, return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self)) dict.__repr__(self))
...@@ -90,14 +99,15 @@ else: ...@@ -90,14 +99,15 @@ else:
import __builtin__ import __builtin__
all = __builtin__.all all = __builtin__.all
any = __builtin__.any any = __builtin__.any
import functools, collections import collections
import functools
partial = functools.partial partial = functools.partial
defaultdict = collections.defaultdict defaultdict = collections.defaultdict
deque = collections.deque deque = collections.deque
__all__ = ['all', 'any'] __all__ = ['all', 'any']
if sys.version_info[:2] < (2,6): if sys.version_info[:2] < (2, 6):
# Borrowed from Python docs # Borrowed from Python docs
def combinations(iterable, r): def combinations(iterable, r):
# combinations('ABCD', 2) --> AB AC AD BC BD CD # combinations('ABCD', 2) --> AB AC AD BC BD CD
...@@ -115,18 +125,17 @@ if sys.version_info[:2] < (2,6): ...@@ -115,18 +125,17 @@ if sys.version_info[:2] < (2,6):
else: else:
return return
indices[i] += 1 indices[i] += 1
for j in range(i+1, r): for j in range(i + 1, r):
indices[j] = indices[j-1] + 1 indices[j] = indices[j - 1] + 1
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
def product(*args, **kwds): def product(*args, **kwds):
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111 # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
pools = map(tuple, args) * kwds.get('repeat', 1) pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]] result = [[]]
for pool in pools: for pool in pools:
result = [x+[y] for x in result for y in pool] result = [x + [y] for x in result for y in pool]
for prod in result: for prod in result:
yield tuple(prod) yield tuple(prod)
......
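
The combinations and product backports above mirror the itertools versions,
and their doc-comments double as executable examples. A quick sanity check
against the stdlib implementations (a sketch; the backports themselves only
run on Python < 2.6):

    from itertools import combinations, product
    from theano.gof.python25 import defaultdict

    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    assert [''.join(c) for c in combinations('ABCD', 2)] == [
        'AB', 'AC', 'AD', 'BC', 'BD', 'CD']
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    assert [''.join(p) for p in product('ABCD', 'xy')] == [
        'Ax', 'Ay', 'Bx', 'By', 'Cx', 'Cy', 'Dx', 'Dy']

    # defaultdict: missing keys are filled in by default_factory
    counts = defaultdict(int)
    counts['a'] += 1
    assert counts['a'] == 1 and counts['b'] == 0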
@@ -21,7 +21,6 @@ from theano.tensor import opt, get_constant_value
from theano import gof
from theano.gof.python25 import maxsize
from theano.compile import optdb
-from theano import config
from theano.compile.function_module import deep_copy_op

import scan_op

@@ -97,7 +96,6 @@ def remove_constants_and_unused_inputs_scan(node):
        try:
            # This works if input is a constant that has all entries
            # equal
-            val = tensor.get_constant_value(node.inputs[idx + 1])
            givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0]
        except TypeError:
            pass

@@ -729,7 +727,6 @@ class ScanSaveMem(gof.Optimizer):
                nw_slice = (fslice,) + tuple(old_slices[1:])
                nw_pos = inv_compress_map[idx]
-                nw_out = new_outs[nw_pos]

                subtens = tensor.basic.Subtensor(nw_slice)
                # slice inputs

@@ -748,7 +745,6 @@ class ScanSaveMem(gof.Optimizer):
            for pos, old_outs in old_outputs:
                if len(old_outs) > 0:
                    nw_pos = compress_map[pos]
-                    nw_out = new_outs[nw_pos]
                    for k, old in enumerate(old_outs):
                        # Get the correct slice
                        cnf_slice, old_slices = slices[pos][k]

@@ -1066,7 +1062,6 @@ def scan_merge_inouts(node):
    else:
        a_inner_outs = a.inner_outputs
        inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)
-        orig_outputs = a.outer_outputs

    op = scan_op.Scan(inner_inputs, inner_outputs, info)
    outputs = op(*outer_inputs)
......
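
For context on the remove_constants_and_unused_inputs_scan hunk above:
get_constant_value returns the Python value of a variable when it can be shown
to be a constant (with all entries equal) and raises TypeError otherwise,
which is what the surrounding try/except relies on. A hedged sketch of that
contract:

    import numpy
    import theano.tensor as tensor

    c = tensor.constant(numpy.asarray(3.0))
    assert tensor.get_constant_value(c) == 3.0  # constant: value is recovered

    x = tensor.dscalar()
    try:
        tensor.get_constant_value(x)  # not derivably constant: raises
    except TypeError:
        pass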
Diff collapsed.
@@ -721,14 +721,14 @@ class ShapeFeature(object):
    def shape_ir(self, i, r):
        """Return symbolic r.shape[i] for tensor variable r, int i"""
-        if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
+        if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
            return self.lscalar_one
        else:
            return Shape_i(i).make_node(r).outputs[0]

    def shape_tuple(self, r):
        """Return a tuple of symbolic shape vars for tensor variable r"""
-        return tuple([self.shape_ir(i,r) for i in xrange(r.ndim)])
+        return tuple([self.shape_ir(i, r) for i in xrange(r.ndim)])

    def default_infer_shape(self, node, i_shapes):
        """Return a list of shape tuple or None for the outputs of node.

@@ -861,7 +861,7 @@ class ShapeFeature(object):
        if r not in self.shape_of:
            try:
                self.set_shape(r, self.shape_tuple(r))
-            except AttributeError: #XXX: where would this come from?
+            except AttributeError:  # XXX: where would this come from?
                self.set_shape(r, None)

    def make_vector_shape(self, r):

@@ -949,17 +949,18 @@ class ShapeFeature(object):
            if sh is None:
                continue
            for i, d in enumerate(sh):
-                # Note: we ignore any shape element that is not typed (i.e. does
-                # not have a 'dtype' attribute). This means there may still
-                # remain int elements that are int32 on 32-bit platforms, but
-                # this works with `local_useless_subtensor`, so for now we
-                # keep it this way. See #266 for a better long-term fix.
+                # Note: we ignore any shape element that is not typed (i.e.,
+                # does not have a 'dtype' attribute). This means there may
+                # still remain int elements that are int32 on 32-bit platforms,
+                # but this works with `local_useless_subtensor`, so for now we
+                # keep it this way. See #266 for a better long-term fix.
                if getattr(d, 'dtype', 'int64') != 'int64':
                    assert d.dtype in theano.tensor.int_dtypes
                    new_shape += sh[len(new_shape):i + 1]
                    new_shape[i] = theano.tensor.cast(d, 'int64')

            if new_shape:
-                # We replace the shape with wrong dtype by the one with 'int64'.
+                # We replace the shape with wrong dtype by the one with
+                # 'int64'.
                new_shape += sh[len(new_shape):]
                o_shapes[sh_idx] = tuple(new_shape)
                new_shape = []

@@ -990,8 +991,8 @@ class ShapeFeature(object):
        for (shpnode, idx) in (r.clients + [(node, i)]):
            if isinstance(getattr(shpnode, 'op', None), Shape_i):
                self.scheduled[shpnode] = new_r
-        # In case 2, if r is a variable that we've scheduled for shape update, then we
-        # should cancel it.
+        # In case 2, if r is a variable that we've scheduled for shape update,
+        # then we should cancel it.
        unscheduled = [k for k, v in self.scheduled.items() if v == r]
        for k in unscheduled:
            del self.scheduled[k]

@@ -1212,9 +1213,10 @@ def local_alloc_unary(node):
class Assert(T.Op):
    """
    Implements assertion in a computational graph.
+
    Notes:
-    This Op can be removed from the graph because of optimizations, and can hide
-    some possible optimizations to the optimizer.
+    This Op can be removed from the graph because of optimizations, and can
+    hide some possible optimizations to the optimizer.
    Also, the output of the Op must be returned by the function computing the
    graph, otherwise it will not be used.
    """
@@ -2773,7 +2775,6 @@ class Canonizer(gof.LocalOptimizer):
        if op not in [self.main, self.inverse, self.reciprocal]:
            return False

-        inputs = node.inputs
        out = node.outputs[0]
        assert len(node.outputs) == 1

@@ -2934,7 +2935,6 @@ def local_sum_div_dimshuffle(node):
        axis = range(node.inputs[0].ndim)
    #print 'axis =', axis
    thing_summed = node.inputs[0]
-    dimshuffled = None
    if thing_summed.owner and thing_summed.owner.op == T.true_div:
        numerator, denominator = thing_summed.owner.inputs

@@ -3035,11 +3035,13 @@ def local_sum_sum(node):
            if summed.owner.op.axis is None:
                # special case of local_cut_useless_reduce
-                return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
+                return [T.Sum(None, dtype=out_dtype)(
+                    summed.owner.inputs[0])]
            if node.op.axis is None:
                # we're summing up everything anyway so lets
                # do it all at once
-                return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
+                return [T.Sum(None, dtype=out_dtype)(
+                    summed.owner.inputs[0])]

            newaxis = list(tuple(summed.owner.op.axis))
            # figure out which dimensions of the original input

@@ -3113,7 +3115,7 @@ def local_sum_alloc(node):
                assert val.size == 1
                val = val.reshape(1)[0] * T.mul(*shapes)
                return [T.cast(val, dtype=node.outputs[0].dtype)]
-            except TypeError, e:
+            except TypeError:
                pass
        else:
            try:

@@ -3127,7 +3129,7 @@ def local_sum_alloc(node):
                return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
                                *[shapes[i] for i in xrange(len(shapes))
                                  if i not in node.op.axis])]
-            except TypeError, e:
+            except TypeError:
                pass

@@ -4433,7 +4435,6 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower.""")

-    otype = node.outputs[0].type
    s_new_out = node.op.scalar_op(*s_g)
    try:
        s_new_out.owner.op.c_code(s_new_out.owner,

@@ -4509,7 +4510,7 @@ class FusionOptimizer(Optimizer):
                        zip(node.outputs, new_outputs),
                        reason=self.__class__.__name__)
                    did_something = True
-                except InconsistencyError, e:
+                except InconsistencyError:
                    pass

        if config.tensor.local_elemwise_fusion:
......
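
The local_sum_sum hunk above folds nested sums into a single Sum; the
underlying identity is plain array arithmetic. In numpy terms (a sketch of
the semantics, not of the optimizer):

    import numpy

    x = numpy.random.rand(3, 4, 5)
    # summing everything after a partial sum equals one full sum
    assert numpy.allclose(x.sum(axis=0).sum(), x.sum())
    # a partial sum of a partial sum is a sum over both original axes
    # (tuple axes need numpy >= 1.7)
    assert numpy.allclose(x.sum(axis=1).sum(axis=0), x.sum(axis=(0, 1)))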