提交 034bb5a3 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

PEP 8 fixes on files I've recently worked on.

上级 41a4a100
import numpy import numpy
import unittest import unittest
import copy
import theano import theano
from theano.tensor import Tensor, TensorType from theano.tensor import Tensor, TensorType
from theano.compile.sharedvalue import * from theano.compile.sharedvalue import *
class Test_SharedVariable(unittest.TestCase): class Test_SharedVariable(unittest.TestCase):
def test_ctors(self): def test_ctors(self):
if 0: #when using an implementation that handles scalars with Scalar type if 0:
# when using an implementation that handles scalars with
# Scalar type
assert shared(7).type == Scalar('int64') assert shared(7).type == Scalar('int64')
assert shared(7.0).type == Scalar('float64') assert shared(7.0).type == Scalar('float64')
assert shared(7, dtype='float64').type == Scalar('float64') assert shared(7, dtype='float64').type == Scalar('float64')
...@@ -24,14 +26,16 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -24,14 +26,16 @@ class Test_SharedVariable(unittest.TestCase):
assert shared(numpy.float32(7)).type == theano.tensor.fscalar assert shared(numpy.float32(7)).type == theano.tensor.fscalar
# test tensor constructor # test tensor constructor
b = shared(numpy.zeros((5,5), dtype='int32')) b = shared(numpy.zeros((5, 5), dtype='int32'))
assert b.type == TensorType('int32', broadcastable=[False,False]) assert b.type == TensorType('int32', broadcastable=[False, False])
b = shared(numpy.random.rand(4,5)) b = shared(numpy.random.rand(4, 5))
assert b.type == TensorType('float64', broadcastable=[False,False]) assert b.type == TensorType('float64', broadcastable=[False, False])
b = shared(numpy.random.rand(5,1,2)) b = shared(numpy.random.rand(5, 1, 2))
assert b.type == TensorType('float64', broadcastable=[False,False,False]) assert b.type == TensorType('float64',
broadcastable=[False, False, False])
assert shared([]).type == generic assert shared([]).type == generic
def badfunc(): def badfunc():
shared(7, bad_kw=False) shared(7, bad_kw=False)
self.assertRaises(TypeError, badfunc) self.assertRaises(TypeError, badfunc)
...@@ -70,7 +74,7 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -70,7 +74,7 @@ class Test_SharedVariable(unittest.TestCase):
SharedVariable( SharedVariable(
name='u', name='u',
type=Tensor(broadcastable=[False], dtype='float64'), type=Tensor(broadcastable=[False], dtype='float64'),
value=[1, 2], #different dtype and not a numpy array value=[1, 2], # different dtype and not a numpy array
strict=False) strict=False)
# here the value is not castable, and we're not strict about it, # here the value is not castable, and we're not strict about it,
...@@ -79,7 +83,7 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -79,7 +83,7 @@ class Test_SharedVariable(unittest.TestCase):
SharedVariable( SharedVariable(
name='u', name='u',
type=Tensor(broadcastable=[False], dtype='float64'), type=Tensor(broadcastable=[False], dtype='float64'),
value=dict(), #not an array by any stretch value=dict(), # not an array by any stretch
strict=False) strict=False)
assert 0 assert 0
except TypeError: except TypeError:
...@@ -96,10 +100,10 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -96,10 +100,10 @@ class Test_SharedVariable(unittest.TestCase):
strict=False) strict=False)
# check that assignments to value are cast properly # check that assignments to value are cast properly
u.set_value([3,4]) u.set_value([3, 4])
assert type(u.get_value()) is numpy.ndarray assert type(u.get_value()) is numpy.ndarray
assert str(u.get_value(borrow=True).dtype) == 'float64' assert str(u.get_value(borrow=True).dtype) == 'float64'
assert numpy.all(u.get_value() == [3,4]) assert numpy.all(u.get_value() == [3, 4])
# check that assignments of nonsense fail # check that assignments of nonsense fail
try: try:
...@@ -109,7 +113,7 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -109,7 +113,7 @@ class Test_SharedVariable(unittest.TestCase):
pass pass
# check that an assignment of a perfect value results in no copying # check that an assignment of a perfect value results in no copying
uval = theano._asarray([5,6,7,8], dtype='float64') uval = theano._asarray([5, 6, 7, 8], dtype='float64')
u.set_value(uval, borrow=True) u.set_value(uval, borrow=True)
assert u.get_value(borrow=True) is uval assert u.get_value(borrow=True) is uval
...@@ -149,10 +153,8 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -149,10 +153,8 @@ class Test_SharedVariable(unittest.TestCase):
assert b.type == theano.tensor.dscalar assert b.type == theano.tensor.dscalar
self.assertRaises(TypeError, f, b, 8) self.assertRaises(TypeError, f, b, 8)
c = shared(numpy.zeros((5,5), dtype='float32')) b = shared(numpy.zeros((5, 5), dtype='float32'))
self.assertRaises(TypeError, f, b, numpy.random.rand(5,5)) self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
def test_tensor_strict(self): def test_tensor_strict(self):
def f(var, val): def f(var, val):
...@@ -192,19 +194,16 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -192,19 +194,16 @@ class Test_SharedVariable(unittest.TestCase):
# assert b.type == theano.tensor.dvector # assert b.type == theano.tensor.dvector
# self.assertRaises(TypeError, f, b, 8) # self.assertRaises(TypeError, f, b, 8)
c = shared(numpy.zeros((5,5), dtype='float32')) b = shared(numpy.zeros((5, 5), dtype='float32'))
self.assertRaises(TypeError, f, b, numpy.random.rand(5,5)) self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
def test_scalar_floatX(self): def test_scalar_floatX(self):
# # the test should assure that floatX is not used in the shared
# the test should assure that floatX is not used in the shared constructor for scalars # constructor for scalars Shared values can change, and since we don't
# Shared values can change, and since we don't know the range they might take, we # know the range they might take, we should keep the same
# should keep the same bit width / precision as the original value used to create the # bit width / precision as the original value used to create the
# shared variable. # shared variable.
#
# Since downcasting of a value now raises an Exception, # Since downcasting of a value now raises an Exception,
...@@ -213,48 +212,46 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -213,48 +212,46 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.int64(7), allow_downcast=True) b = shared(numpy.int64(7), allow_downcast=True)
assert b.type == theano.tensor.lscalar assert b.type == theano.tensor.lscalar
f(b,8.23) f(b, 8.23)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.int32(7), allow_downcast=True) b = shared(numpy.int32(7), allow_downcast=True)
assert b.type == theano.tensor.iscalar assert b.type == theano.tensor.iscalar
f(b,8.23) f(b, 8.23)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.int16(7), allow_downcast=True) b = shared(numpy.int16(7), allow_downcast=True)
assert b.type == theano.tensor.wscalar assert b.type == theano.tensor.wscalar
f(b,8.23) f(b, 8.23)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.int8(7), allow_downcast=True) b = shared(numpy.int8(7), allow_downcast=True)
assert b.type == theano.tensor.bscalar assert b.type == theano.tensor.bscalar
f(b,8.23) f(b, 8.23)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.float64(7.234), allow_downcast=True) b = shared(numpy.float64(7.234), allow_downcast=True)
assert b.type == theano.tensor.dscalar assert b.type == theano.tensor.dscalar
f(b,8) f(b, 8)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.float32(7.234), allow_downcast=True) b = shared(numpy.float32(7.234), allow_downcast=True)
assert b.type == theano.tensor.fscalar assert b.type == theano.tensor.fscalar
f(b,8) f(b, 8)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(numpy.float(7.234), allow_downcast=True) b = shared(numpy.float(7.234), allow_downcast=True)
assert b.type == theano.tensor.dscalar assert b.type == theano.tensor.dscalar
f(b,8) f(b, 8)
assert b.get_value()==8 assert b.get_value() == 8
b = shared(7.234, allow_downcast=True) b = shared(7.234, allow_downcast=True)
assert b.type == theano.tensor.dscalar assert b.type == theano.tensor.dscalar
f(b,8) f(b, 8)
assert b.get_value()==8 assert b.get_value() == 8
c = shared(numpy.zeros((5,5), dtype='float32'), allow_downcast=True)
self.assertRaises(TypeError, f, b, numpy.random.rand(5,5))
b = shared(numpy.zeros((5, 5), dtype='float32'), allow_downcast=True)
self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
def test_tensor_floatX(self): def test_tensor_floatX(self):
def f(var, val): def f(var, val):
...@@ -262,32 +259,32 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -262,32 +259,32 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.int64([7]), allow_downcast=True) b = shared(numpy.int64([7]), allow_downcast=True)
assert b.type == theano.tensor.lvector assert b.type == theano.tensor.lvector
f(b,[8.23]) f(b, [8.23])
assert b.get_value() == 8 assert b.get_value() == 8
b = shared(numpy.int32([7]), allow_downcast=True) b = shared(numpy.int32([7]), allow_downcast=True)
assert b.type == theano.tensor.ivector assert b.type == theano.tensor.ivector
f(b,[8.23]) f(b, [8.23])
assert b.get_value() == 8 assert b.get_value() == 8
b = shared(numpy.int16([7]), allow_downcast=True) b = shared(numpy.int16([7]), allow_downcast=True)
assert b.type == theano.tensor.wvector assert b.type == theano.tensor.wvector
f(b,[8.23]) f(b, [8.23])
assert b.get_value() == 8 assert b.get_value() == 8
b = shared(numpy.int8([7]), allow_downcast=True) b = shared(numpy.int8([7]), allow_downcast=True)
assert b.type == theano.tensor.bvector assert b.type == theano.tensor.bvector
f(b,[8.23]) f(b, [8.23])
assert b.get_value() == 8 assert b.get_value() == 8
b = shared(numpy.float64([7.234]), allow_downcast=True) b = shared(numpy.float64([7.234]), allow_downcast=True)
assert b.type == theano.tensor.dvector assert b.type == theano.tensor.dvector
f(b,[8]) f(b, [8])
assert b.get_value() == 8 assert b.get_value() == 8
b = shared(numpy.float32([7.234]), allow_downcast=True) b = shared(numpy.float32([7.234]), allow_downcast=True)
assert b.type == theano.tensor.fvector assert b.type == theano.tensor.fvector
f(b,[8]) f(b, [8])
assert b.get_value() == 8 assert b.get_value() == 8
#numpy.float([7.234]) don't work #numpy.float([7.234]) don't work
...@@ -300,10 +297,12 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -300,10 +297,12 @@ class Test_SharedVariable(unittest.TestCase):
# assert b.type == theano.tensor.dvector # assert b.type == theano.tensor.dvector
# f(b,[8]) # f(b,[8])
b = shared(numpy.asarray([7.234],dtype=theano.config.floatX), allow_downcast=True) b = shared(numpy.asarray([7.234], dtype=theano.config.floatX),
allow_downcast=True)
assert b.dtype == theano.config.floatX assert b.dtype == theano.config.floatX
f(b,[8]) f(b, [8])
assert b.get_value() == 8 assert b.get_value() == 8
c = shared(numpy.zeros((5,5), dtype='float32'), allow_downcast=True) b = shared(numpy.zeros((5, 5), dtype='float32'),
self.assertRaises(TypeError, f, b, numpy.random.rand(5,5)) allow_downcast=True)
self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
...@@ -5,7 +5,6 @@ import cPickle ...@@ -5,7 +5,6 @@ import cPickle
import logging import logging
import operator import operator
import os import os
import platform
import shutil import shutil
import stat import stat
import StringIO import StringIO
...@@ -17,18 +16,21 @@ import time ...@@ -17,18 +16,21 @@ import time
import distutils.sysconfig import distutils.sysconfig
import numpy.distutils #TODO: TensorType should handle this import numpy.distutils # TODO: TensorType should handle this
import theano
from theano.configparser import config from theano.configparser import config
from theano.gof.cc import hash_from_code, hash_from_file from theano.gof.cc import hash_from_code
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
from theano.configparser import TheanoConfigParser, AddConfigVar, EnumStr, StrParam, IntParam, FloatParam, BoolParam # we will abuse the lockfile mechanism when reading and writing the registry
import compilelock
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('cmodule.mac_framework_link', AddConfigVar('cmodule.mac_framework_link',
"If set to true, breaks certain mac installations with the infamous Bus Error", ("If set to true, breaks certain mac installations with the infamous "
"Bus Error"),
BoolParam(False)) BoolParam(False))
def local_bitwidth(): def local_bitwidth():
""" """
Return 32 for 32bit arch, 64 for 64bit arch Return 32 for 32bit arch, 64 for 64bit arch
...@@ -42,6 +44,7 @@ def local_bitwidth(): ...@@ -42,6 +44,7 @@ def local_bitwidth():
# 'P' denotes a void*, and the size is expressed in bytes. # 'P' denotes a void*, and the size is expressed in bytes.
return struct.calcsize('P') * 8 return struct.calcsize('P') * 8
def python_int_bitwidth(): def python_int_bitwidth():
""" """
Return the bit width of Python int (C long int). Return the bit width of Python int (C long int).
...@@ -51,11 +54,12 @@ def python_int_bitwidth(): ...@@ -51,11 +54,12 @@ def python_int_bitwidth():
# 'l' denotes a C long int, and the size is expressed in bytes. # 'l' denotes a C long int, and the size is expressed in bytes.
return struct.calcsize('l') * 8 return struct.calcsize('l') * 8
_logger=logging.getLogger("theano.gof.cmodule") _logger = logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARNING) _logger.setLevel(logging.WARNING)
METH_VARARGS="METH_VARARGS" METH_VARARGS = "METH_VARARGS"
METH_NOARGS="METH_NOARGS" METH_NOARGS = "METH_NOARGS"
def debug_counter(name, every=1): def debug_counter(name, every=1):
"""Debug counter to know how often we go through some piece of code. """Debug counter to know how often we go through some piece of code.
...@@ -68,6 +72,7 @@ def debug_counter(name, every=1): ...@@ -68,6 +72,7 @@ def debug_counter(name, every=1):
if n % every == 0: if n % every == 0:
print >>sys.stderr, "debug_counter [%s]: %s" % (name, n) print >>sys.stderr, "debug_counter [%s]: %s" % (name, n)
class ExtFunction(object): class ExtFunction(object):
"""A C function to put into a DynamicModule """ """A C function to put into a DynamicModule """
...@@ -75,14 +80,18 @@ class ExtFunction(object): ...@@ -75,14 +80,18 @@ class ExtFunction(object):
"""string - function's name""" """string - function's name"""
code_block = "" code_block = ""
"""string - the entire code for the function. Has the form ``static PyObject* """string - the entire code for the function.
<name>([...]){ ... }
Has the form ``static PyObject* <name>([...]){ ... }
See Python's C API Reference for how to write c functions for python modules. See Python's C API Reference for how to write c functions for python
modules.
""" """
method = "" method = ""
"""str - calling method for this function (i.e. 'METH_VARARGS', 'METH_NOARGS')""" """
str - calling method for this function (i.e. 'METH_VARARGS', 'METH_NOARGS')
"""
doc = "" doc = ""
"""str - documentation string for this function""" """str - documentation string for this function"""
...@@ -94,8 +103,14 @@ class ExtFunction(object): ...@@ -94,8 +103,14 @@ class ExtFunction(object):
self.doc = doc self.doc = doc
def method_decl(self): def method_decl(self):
"""Returns the signature for this function that goes into the DynamicModule's method table""" """
return '\t{"%s", %s, %s, "%s"}' %(self.name, self.name, self.method, self.doc) Returns the signature for this function.
It goes into the DynamicModule's method table.
"""
return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc)
class DynamicModule(object): class DynamicModule(object):
def __init__(self, name): def __init__(self, name):
...@@ -103,8 +118,12 @@ class DynamicModule(object): ...@@ -103,8 +118,12 @@ class DynamicModule(object):
self.support_code = [] self.support_code = []
self.functions = [] self.functions = []
self.includes = ["<Python.h>", "<iostream>"] self.includes = ["<Python.h>", "<iostream>"]
self.includes.append('<numpy/arrayobject.h>') #TODO: this should come from TensorType
self.init_blocks = ['import_array();'] #TODO: from TensorType #TODO: this should come from TensorType
self.includes.append('<numpy/arrayobject.h>')
#TODO: from TensorType
self.init_blocks = ['import_array();']
def print_methoddef(self, stream): def print_methoddef(self, stream):
print >> stream, "static PyMethodDef MyMethods[] = {" print >> stream, "static PyMethodDef MyMethods[] = {"
...@@ -117,22 +136,23 @@ class DynamicModule(object): ...@@ -117,22 +136,23 @@ class DynamicModule(object):
print >> stream, "PyMODINIT_FUNC init%s(void){" % self.name print >> stream, "PyMODINIT_FUNC init%s(void){" % self.name
for b in self.init_blocks: for b in self.init_blocks:
print >> stream, ' ', b print >> stream, ' ', b
print >> stream, ' ', '(void) Py_InitModule("%s", MyMethods);' % self.name print >> stream, ' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.name)
print >> stream, "}" print >> stream, "}"
def add_include(self, str): def add_include(self, str):
self.includes.append(str) self.includes.append(str)
def add_init_code(self, code): def add_init_code(self, code):
self.init_blocks.append(code) self.init_blocks.append(code)
def add_support_code(self, code): def add_support_code(self, code):
if code not in self.support_code: #TODO: KLUDGE if code not in self.support_code: # TODO: KLUDGE
self.support_code.append(code) self.support_code.append(code)
def add_function(self, fn): def add_function(self, fn):
self.functions.append(fn) self.functions.append(fn)
def code(self): def code(self):
sio = StringIO.StringIO() sio = StringIO.StringIO()
for inc in self.includes: for inc in self.includes:
...@@ -141,23 +161,23 @@ class DynamicModule(object): ...@@ -141,23 +161,23 @@ class DynamicModule(object):
if inc[0] == '<' or inc[0] == '"': if inc[0] == '<' or inc[0] == '"':
print >> sio, "#include", inc print >> sio, "#include", inc
else: else:
print >> sio, '#include "%s"'%inc print >> sio, '#include "%s"' % inc
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
print >> sio, "//// Support Code" print >> sio, "//// Support Code"
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
for sc in self.support_code: for sc in self.support_code:
print >> sio, sc print >> sio, sc
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
print >> sio, "//// Functions" print >> sio, "//// Functions"
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
for f in self.functions: for f in self.functions:
print >> sio, f.code_block print >> sio, f.code_block
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
print >> sio, "//// Module init" print >> sio, "//// Module init"
print >> sio, "//////////////////////" print >> sio, "//////////////////////"
self.print_methoddef(sio) self.print_methoddef(sio)
self.print_init(sio) self.print_init(sio)
...@@ -166,17 +186,19 @@ class DynamicModule(object): ...@@ -166,17 +186,19 @@ class DynamicModule(object):
def list_code(self, ofile=sys.stdout): def list_code(self, ofile=sys.stdout):
"""Print out the code with line numbers to `ofile` """ """Print out the code with line numbers to `ofile` """
for i, line in enumerate(self.code().split('\n')): for i, line in enumerate(self.code().split('\n')):
print >> ofile, '%4i'%(i+1), line print >> ofile, ('%4i' % (i + 1)), line
ofile.flush() ofile.flush()
#TODO: add_type #TODO: add_type
def dlimport(fullpath, suffix=None): def dlimport(fullpath, suffix=None):
"""Dynamically load a .so, .pyd, .dll, or .py file """Dynamically load a .so, .pyd, .dll, or .py file
:type fullpath: string :type fullpath: string
:param fullpath: a fully-qualified path do a compiled python module :param fullpath: a fully-qualified path do a compiled python module
:param suffix: a suffix to strip from the end of fullpath to get the import name :param suffix: a suffix to strip from the end of fullpath to get the
import name
:type suffix: string :type suffix: string
:returns: the dynamically loaded module (from __import__) :returns: the dynamically loaded module (from __import__)
...@@ -200,12 +222,12 @@ def dlimport(fullpath, suffix=None): ...@@ -200,12 +222,12 @@ def dlimport(fullpath, suffix=None):
module_name = '.'.join(fullpath.split(os.path.sep)[-2:])[:-len(suffix)] module_name = '.'.join(fullpath.split(os.path.sep)[-2:])[:-len(suffix)]
else: else:
raise ValueError('path has wrong suffix', (fullpath, suffix)) raise ValueError('path has wrong suffix', (fullpath, suffix))
workdir = fullpath[:-len(module_name)- 1 - len(suffix)] workdir = fullpath[:-len(module_name) - 1 - len(suffix)]
_logger.debug("WORKDIR %s", workdir) _logger.debug("WORKDIR %s", workdir)
_logger.debug("module_name %s", module_name) _logger.debug("module_name %s", module_name)
sys.path[0:0] = [workdir] #insert workdir at beginning (temporarily) sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
try: try:
rval = __import__(module_name, {}, {}, [module_name]) rval = __import__(module_name, {}, {}, [module_name])
if not rval: if not rval:
...@@ -216,20 +238,32 @@ def dlimport(fullpath, suffix=None): ...@@ -216,20 +238,32 @@ def dlimport(fullpath, suffix=None):
assert fullpath.startswith(rval.__file__) assert fullpath.startswith(rval.__file__)
return rval return rval
def dlimport_workdir(basedir): def dlimport_workdir(basedir):
"""Return a directory where you should put your .so file for dlimport to be able to load """
it, given a basedir which should normally be config.compiledir""" Return a directory where you should put your .so file for dlimport
to be able to load it, given a basedir which should normally be
config.compiledir
"""
return tempfile.mkdtemp(dir=basedir) return tempfile.mkdtemp(dir=basedir)
def last_access_time(path): def last_access_time(path):
"""Return the number of seconds since the epoch of the last access of a given file""" """
Return the number of seconds since the epoch of the last access of a
given file.
"""
return os.stat(path)[stat.ST_ATIME] return os.stat(path)[stat.ST_ATIME]
def module_name_from_dir(dirname): def module_name_from_dir(dirname):
"""Scan the contents of a cache directory and return full path of the dynamic lib in it. """
Scan the contents of a cache directory and return full path of the
dynamic lib in it.
""" """
files = os.listdir(dirname) files = os.listdir(dirname)
name, = [file for file in files if file.endswith('.so') or file.endswith('.pyd')] name, = [file for file in files
if file.endswith('.so') or file.endswith('.pyd')]
return os.path.join(dirname, name) return os.path.join(dirname, name)
...@@ -322,7 +356,8 @@ def get_safe_part(key): ...@@ -322,7 +356,8 @@ def get_safe_part(key):
# Find the md5 hash part. # Find the md5 hash part.
c_link_key = key[1] c_link_key = key[1]
for key_element in c_link_key[1:]: for key_element in c_link_key[1:]:
if isinstance(key_element, basestring) and key_element.startswith('md5:'): if (isinstance(key_element, basestring)
and key_element.startswith('md5:')):
md5 = key_element[4:] md5 = key_element[4:]
break break
...@@ -375,7 +410,8 @@ class KeyData(object): ...@@ -375,7 +410,8 @@ class KeyData(object):
cPickle.dump(self, open(self.key_pkl, 'wb'), cPickle.dump(self, open(self.key_pkl, 'wb'),
protocol=cPickle.HIGHEST_PROTOCOL) protocol=cPickle.HIGHEST_PROTOCOL)
except cPickle.PicklingError: except cPickle.PicklingError:
_logger.warning("Cache leak due to unpickle-able key data %s", self.keys) _logger.warning("Cache leak due to unpickle-able key data %s",
self.keys)
os.remove(self.key_pkl) os.remove(self.key_pkl)
raise raise
...@@ -411,9 +447,9 @@ class KeyData(object): ...@@ -411,9 +447,9 @@ class KeyData(object):
class ModuleCache(object): class ModuleCache(object):
"""Interface to the cache of dynamically compiled modules on disk """Interface to the cache of dynamically compiled modules on disk
Note that this interface does not assume exclusive use of the cache directory. Note that this interface does not assume exclusive use of the cache
It is built to handle the case where multiple programs are also using instances of this directory. It is built to handle the case where multiple programs are also
class to manage the same directory. using instances of this class to manage the same directory.
The cache works on the basis of keys. Each key is mapped to only one The cache works on the basis of keys. Each key is mapped to only one
dynamic module, but multiple keys may be mapped to the same module (see dynamic module, but multiple keys may be mapped to the same module (see
...@@ -475,7 +511,9 @@ class ModuleCache(object): ...@@ -475,7 +511,9 @@ class ModuleCache(object):
"""Maps a module hash to its corresponding KeyData object.""" """Maps a module hash to its corresponding KeyData object."""
stats = [] stats = []
"""A list with counters for the number of hits, loads, compiles issued by module_from_key() """
A list with counters for the number of hits, loads, compiles issued by
module_from_key()
""" """
loaded_key_pkl = set() loaded_key_pkl = set()
...@@ -504,7 +542,7 @@ class ModuleCache(object): ...@@ -504,7 +542,7 @@ class ModuleCache(object):
if do_refresh: if do_refresh:
self.refresh() self.refresh()
age_thresh_use = 60*60*24*24 # 24 days age_thresh_use = 60 * 60 * 24 * 24 # 24 days
""" """
The default age threshold (in seconds) for cache files we want to use. The default age threshold (in seconds) for cache files we want to use.
...@@ -552,10 +590,11 @@ class ModuleCache(object): ...@@ -552,10 +590,11 @@ class ModuleCache(object):
elif 'key.pkl' in files: elif 'key.pkl' in files:
try: try:
entry = module_name_from_dir(root) entry = module_name_from_dir(root)
except ValueError: # there is a key but no dll! except ValueError: # there is a key but no dll!
if not root.startswith("/tmp"): if not root.startswith("/tmp"):
# Under /tmp, file are removed periodically by the os. # Under /tmp, file are removed periodically by the
# So it is normal that this happens from time to time. # os. So it is normal that this happens from time
# to time.
_logger.warning("ModuleCache.refresh() Found key " _logger.warning("ModuleCache.refresh() Found key "
"without dll in cache, deleting it. %s", "without dll in cache, deleting it. %s",
key_pkl) key_pkl)
...@@ -564,9 +603,11 @@ class ModuleCache(object): ...@@ -564,9 +603,11 @@ class ModuleCache(object):
continue continue
if (time_now - last_access_time(entry)) < age_thresh_use: if (time_now - last_access_time(entry)) < age_thresh_use:
_logger.debug('refresh adding %s', key_pkl) _logger.debug('refresh adding %s', key_pkl)
def unpickle_failure(): def unpickle_failure():
_logger.info("ModuleCache.refresh() Failed to " _logger.info("ModuleCache.refresh() Failed to "
"unpickle cache file %s", key_pkl) "unpickle cache file %s", key_pkl)
try: try:
key_data = cPickle.load(open(key_pkl, 'rb')) key_data = cPickle.load(open(key_pkl, 'rb'))
except EOFError: except EOFError:
...@@ -632,12 +673,14 @@ class ModuleCache(object): ...@@ -632,12 +673,14 @@ class ModuleCache(object):
# TODO: check if this can happen at all # TODO: check if this can happen at all
to_del = [key for key in key_data.keys if not key[0]] to_del = [key for key in key_data.keys if not key[0]]
if to_del: if to_del:
_logger.warning("ModuleCache.refresh() Found unversioned " _logger.warning(
"ModuleCache.refresh() Found unversioned "
"key in cache, removing it. %s", key_pkl) "key in cache, removing it. %s", key_pkl)
# Since the version is in the module hash, all # Since the version is in the module hash, all
# keys should be unversioned. # keys should be unversioned.
if len(to_del) != len(key_data.keys): if len(to_del) != len(key_data.keys):
_logger.warning('Found a mix of unversioned and ' _logger.warning(
'Found a mix of unversioned and '
'versioned keys for the same ' 'versioned keys for the same '
'module %s', key_pkl) 'module %s', key_pkl)
_rmtree(root, ignore_nocleanup=True, _rmtree(root, ignore_nocleanup=True,
...@@ -726,14 +769,15 @@ class ModuleCache(object): ...@@ -726,14 +769,15 @@ class ModuleCache(object):
key_data.delete_keys_from(self.entry_from_key) key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash] del self.module_hash_to_key_data[module_hash]
if key[0]: if key[0]:
# this is a versioned entry, so should have been on disk # this is a versioned entry, so should have been on
# Something weird happened to cause this, so we are responding by # disk. Something weird happened to cause this, so we
# printing a warning, removing evidence that we ever saw this mystery # are responding by printing a warning, removing
# key. # evidence that we ever saw this mystery key.
pkl_file_to_remove = key_data.key_pkl pkl_file_to_remove = key_data.key_pkl
if not root.startswith("/tmp"): if not root.startswith("/tmp"):
# Under /tmp, file are removed periodically by the os. # Under /tmp, file are removed periodically by the
# So it is normal that this happen from time to time. # os. So it is normal that this happen from time to
# time.
_logger.warning("Removing key file %s because the " _logger.warning("Removing key file %s because the "
"corresponding module is gone from the " "corresponding module is gone from the "
"file system.", "file system.",
...@@ -813,8 +857,8 @@ class ModuleCache(object): ...@@ -813,8 +857,8 @@ class ModuleCache(object):
try: try:
compile_steps = fn(location=location).__iter__() compile_steps = fn(location=location).__iter__()
# Check if we already know a module with the same hash. If we # Check if we already know a module with the same hash.
# do, then there is no need to even compile it. # If we do, then there is no need to even compile it.
duplicated_module = False duplicated_module = False
# The first compilation step is to yield the source code. # The first compilation step is to yield the source code.
src_code = compile_steps.next() src_code = compile_steps.next()
...@@ -828,11 +872,12 @@ class ModuleCache(object): ...@@ -828,11 +872,12 @@ class ModuleCache(object):
# Note that we do not pass the `fn` argument, since it # Note that we do not pass the `fn` argument, since it
# should not be used considering that the module should # should not be used considering that the module should
# already be compiled. # already be compiled.
module = self.module_from_key(key=None, key_data=key_data) module = self.module_from_key(key=None,
key_data=key_data)
name = module.__file__ name = module.__file__
# Add current key to the set of keys associated to the same # Add current key to the set of keys associated to the
# module. We only save the KeyData object of versioned # same module. We only save the KeyData object of
# modules. # versioned modules.
try: try:
key_data.add_key(key, save_pkl=bool(_version)) key_data.add_key(key, save_pkl=bool(_version))
key_broken = False key_broken = False
...@@ -840,8 +885,8 @@ class ModuleCache(object): ...@@ -840,8 +885,8 @@ class ModuleCache(object):
# This should only happen if we tried to save the # This should only happen if we tried to save the
# pickled file. # pickled file.
assert _version assert _version
# The key we are trying to add is broken: we will not # The key we are trying to add is broken: we will
# add it after all. # not add it after all.
key_data.remove_key(key) key_data.remove_key(key)
key_broken = True key_broken = True
...@@ -868,12 +913,13 @@ class ModuleCache(object): ...@@ -868,12 +913,13 @@ class ModuleCache(object):
# Obtain path to the '.so' module file. # Obtain path to the '.so' module file.
name = module.__file__ name = module.__file__
_logger.debug("Adding module to cache %s %s", key, name) _logger.debug("Adding module to cache %s %s",
key, name)
assert name.startswith(location) assert name.startswith(location)
assert name not in self.module_from_name assert name not in self.module_from_name
# Changing the hash of the key is not allowed during # Changing the hash of the key is not allowed during
# compilation. That is the only cause found that makes the # compilation. That is the only cause found that makes
# following assert fail. # the following assert fail.
assert hash(key) == hash_key assert hash(key) == hash_key
assert key not in self.entry_from_key assert key not in self.entry_from_key
...@@ -896,10 +942,11 @@ class ModuleCache(object): ...@@ -896,10 +942,11 @@ class ModuleCache(object):
key_broken = False key_broken = False
except cPickle.PicklingError: except cPickle.PicklingError:
key_broken = True key_broken = True
# Remove key from the KeyData object, to make sure # Remove key from the KeyData object, to make
# we never try to save it again. # sure we never try to save it again.
# We still keep the KeyData object and save it so # We still keep the KeyData object and save it
# that the module can be re-used in the future. # so that the module can be re-used in the
# future.
key_data.keys = set() key_data.keys = set()
key_data.save_pkl() key_data.save_pkl()
...@@ -910,20 +957,21 @@ class ModuleCache(object): ...@@ -910,20 +957,21 @@ class ModuleCache(object):
# versioned module. # versioned module.
self.loaded_key_pkl.add(key_pkl) self.loaded_key_pkl.add(key_pkl)
# Map the new module to its KeyData object. Note that we # Map the new module to its KeyData object. Note that
# need to do it regardless of whether the key is versioned # we need to do it regardless of whether the key is
# or not if we want to be able to re-use this module inside # versioned or not if we want to be able to re-use this
# the same process. # module inside the same process.
self.module_hash_to_key_data[module_hash] = key_data self.module_hash_to_key_data[module_hash] = key_data
except Exception: except Exception:
# This may happen e.g. when an Op has no C implementation. In # This may happen e.g. when an Op has no C implementation.
# any case, we do not want to keep around the temporary work # In any case, we do not want to keep around the temporary
# directory, as it may cause trouble if we create too many of # work directory, as it may cause trouble if we create too
# these. The 'ignore_if_missing' flag is set just in case this # many of these. The 'ignore_if_missing' flag is set just
# directory would have already been deleted. # in case this directory would have already been deleted.
_rmtree(location, ignore_if_missing=True, _rmtree(location, ignore_if_missing=True,
msg='exception -- typically means no C implementation') msg=('exception -- '
'typically means no C implementation'))
raise raise
finally: finally:
...@@ -982,8 +1030,9 @@ class ModuleCache(object): ...@@ -982,8 +1030,9 @@ class ModuleCache(object):
if key_data.keys: if key_data.keys:
# This is to make debugging in pdb easier, by providing # This is to make debugging in pdb easier, by providing
# the offending keys in the local context. # the offending keys in the local context.
key_data_keys = list(key_data.keys) # key_data_keys = list(key_data.keys)
## import pdb; pdb.set_trace() ## import pdb; pdb.set_trace()
pass
elif found > 1: elif found > 1:
msg = 'Multiple equal keys found in unpickled KeyData file' msg = 'Multiple equal keys found in unpickled KeyData file'
if msg: if msg:
...@@ -1005,8 +1054,8 @@ class ModuleCache(object): ...@@ -1005,8 +1054,8 @@ class ModuleCache(object):
self.time_spent_in_check_key += time.time() - start_time self.time_spent_in_check_key += time.time() - start_time
age_thresh_del = 60*60*24*31 # 31 days age_thresh_del = 60 * 60 * 24 * 31 # 31 days
age_thresh_del_unversioned = 60*60*24*7 # 7 days age_thresh_del_unversioned = 60 * 60 * 24 * 7 # 7 days
"""The default age threshold for `clear_old` (in seconds) """The default age threshold for `clear_old` (in seconds)
""" """
...@@ -1090,7 +1139,8 @@ class ModuleCache(object): ...@@ -1090,7 +1139,8 @@ class ModuleCache(object):
def clear_base_files(self): def clear_base_files(self):
""" """
Remove base directories 'cuda_ndarray', 'cutils_ext', 'lazylinker_ext' and 'scan_perform' if present. Remove base directories 'cuda_ndarray', 'cutils_ext', 'lazylinker_ext'
and 'scan_perform' if present.
Note that we do not delete them outright because it may not work on Note that we do not delete them outright because it may not work on
some systems due to these modules being currently in use. Instead we some systems due to these modules being currently in use. Instead we
...@@ -1099,7 +1149,8 @@ class ModuleCache(object): ...@@ -1099,7 +1149,8 @@ class ModuleCache(object):
""" """
compilelock.get_lock() compilelock.get_lock()
try: try:
for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext', 'scan_perform'): for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext',
'scan_perform'):
to_delete = os.path.join(self.dirname, base_dir + '.delete.me') to_delete = os.path.join(self.dirname, base_dir + '.delete.me')
if os.path.isdir(to_delete): if os.path.isdir(to_delete):
try: try:
...@@ -1174,18 +1225,20 @@ class ModuleCache(object): ...@@ -1174,18 +1225,20 @@ class ModuleCache(object):
for filename in os.listdir(self.dirname): for filename in os.listdir(self.dirname):
if filename.startswith('tmp'): if filename.startswith('tmp'):
try: try:
open(os.path.join(self.dirname, filename, 'key.pkl')).close() open(os.path.join(self.dirname, filename, 'key.pkl')
).close()
has_key = True has_key = True
except IOError: except IOError:
has_key = False has_key = False
if not has_key: if not has_key:
age = time_now - last_access_time(os.path.join(self.dirname, filename)) age = time_now - last_access_time(
# In normal case, the processus that created this directory os.path.join(self.dirname, filename))
# will delete it. However, if this processus crashes, it # In normal case, the processus that created this
# will not be cleaned up. # directory will delete it. However, if this processus
# As we don't know if this directory is still used, we wait # crashes, it will not be cleaned up.
# one week and suppose that the processus crashed, and we # As we don't know if this directory is still used,
# take care of the clean-up. # we wait one week and suppose that the processus
# crashed, and we take care of the clean-up.
if age > min_age: if age > min_age:
_rmtree(os.path.join(self.dirname, filename), _rmtree(os.path.join(self.dirname, filename),
msg='old unversioned', level=logging.INFO, msg='old unversioned', level=logging.INFO,
...@@ -1204,6 +1257,7 @@ class ModuleCache(object): ...@@ -1204,6 +1257,7 @@ class ModuleCache(object):
_logger.debug('Time spent checking keys: %s', _logger.debug('Time spent checking keys: %s',
self.time_spent_in_check_key) self.time_spent_in_check_key)
def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG, def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG,
ignore_if_missing=False): ignore_if_missing=False):
# On NFS filesystems, it is impossible to delete a directory with open # On NFS filesystems, it is impossible to delete a directory with open
...@@ -1226,12 +1280,14 @@ def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG, ...@@ -1226,12 +1280,14 @@ def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG,
if os.path.exists(parent): if os.path.exists(parent):
try: try:
_logger.info('placing "delete.me" in %s', parent) _logger.info('placing "delete.me" in %s', parent)
open(os.path.join(parent,'delete.me'), 'w').close() open(os.path.join(parent, 'delete.me'), 'w').close()
except Exception, ee: except Exception, ee:
_logger.warning("Failed to remove or mark cache directory %s " _logger.warning("Failed to remove or mark cache directory %s "
"for removal %s", parent, ee) "for removal %s", parent, ee)
_module_cache = None _module_cache = None
def get_module_cache(dirname, init_args=None): def get_module_cache(dirname, init_args=None):
""" """
:param init_args: If not None, the (k, v) pairs in this dictionary will :param init_args: If not None, the (k, v) pairs in this dictionary will
...@@ -1252,6 +1308,7 @@ def get_module_cache(dirname, init_args=None): ...@@ -1252,6 +1308,7 @@ def get_module_cache(dirname, init_args=None):
_module_cache.dirname, dirname) _module_cache.dirname, dirname)
return _module_cache return _module_cache
def get_lib_extension(): def get_lib_extension():
"""Return the platform-dependent extension for compiled modules.""" """Return the platform-dependent extension for compiled modules."""
if sys.platform == 'win32': if sys.platform == 'win32':
...@@ -1259,6 +1316,7 @@ def get_lib_extension(): ...@@ -1259,6 +1316,7 @@ def get_lib_extension():
else: else:
return 'so' return 'so'
def get_gcc_shared_library_arg(): def get_gcc_shared_library_arg():
"""Return the platform-dependent GCC argument for shared libraries.""" """Return the platform-dependent GCC argument for shared libraries."""
if sys.platform == 'darwin': if sys.platform == 'darwin':
...@@ -1266,29 +1324,33 @@ def get_gcc_shared_library_arg(): ...@@ -1266,29 +1324,33 @@ def get_gcc_shared_library_arg():
else: else:
return '-shared' return '-shared'
def std_include_dirs(): def std_include_dirs():
return numpy.distutils.misc_util.get_numpy_include_dirs() + [distutils.sysconfig.get_python_inc()] return (numpy.distutils.misc_util.get_numpy_include_dirs()
+ [distutils.sysconfig.get_python_inc()])
def std_lib_dirs_and_libs(): def std_lib_dirs_and_libs():
python_inc = distutils.sysconfig.get_python_inc() python_inc = distutils.sysconfig.get_python_inc()
if sys.platform == 'win32': if sys.platform == 'win32':
# Obtain the library name from the Python version instead of the # Obtain the library name from the Python version instead of the
# installation directory, in case the user defined a custom installation # installation directory, in case the user defined a custom
# directory. # installation directory.
python_version = distutils.sysconfig.get_python_version() python_version = distutils.sysconfig.get_python_version()
libname = 'python' + python_version.replace('.', '') libname = 'python' + python_version.replace('.', '')
# Also add directory containing the Python library to the library # Also add directory containing the Python library to the library
# directories. # directories.
python_lib_dir = os.path.join(os.path.dirname(python_inc), 'libs') python_lib_dir = os.path.join(os.path.dirname(python_inc), 'libs')
lib_dirs = [python_lib_dir]
return [libname], [python_lib_dir] return [libname], [python_lib_dir]
#DSE Patch 2 for supporting OSX frameworks. Suppress -lpython2.x when frameworks are present
elif sys.platform=='darwin' : # DSE Patch 2 for supporting OSX frameworks.
if python_inc.count('Python.framework') : # Suppress -lpython2.x when frameworks are present
return [],[] elif sys.platform == 'darwin':
else : if python_inc.count('Python.framework'):
libname=os.path.basename(python_inc) return [], []
return [libname],[] else:
libname = os.path.basename(python_inc)
return [libname], []
else: else:
# Typical include directory: /usr/include/python2.6 # Typical include directory: /usr/include/python2.6
libname = os.path.basename(python_inc) libname = os.path.basename(python_inc)
...@@ -1399,14 +1461,10 @@ class GCC_compiler(object): ...@@ -1399,14 +1461,10 @@ class GCC_compiler(object):
if python_lib not in lib_dirs: if python_lib not in lib_dirs:
lib_dirs.append(python_lib) lib_dirs.append(python_lib)
workdir = location
cppfilename = os.path.join(location, 'mod.cpp') cppfilename = os.path.join(location, 'mod.cpp')
cppfile = file(cppfilename, 'w') cppfile = file(cppfilename, 'w')
_logger.debug('Writing module C++ code to %s', cppfilename) _logger.debug('Writing module C++ code to %s', cppfilename)
ofiles = []
rval = None
cppfile.write(src_code) cppfile.write(src_code)
# Avoid gcc warning "no newline at end of file". # Avoid gcc warning "no newline at end of file".
...@@ -1433,8 +1491,9 @@ class GCC_compiler(object): ...@@ -1433,8 +1491,9 @@ class GCC_compiler(object):
def print_command_line_error(): def print_command_line_error():
# Print command line when a problem occurred. # Print command line when a problem occurred.
print >> sys.stderr, ("Problem occurred during compilation with the " print >> sys.stderr, (
"command line below:") "Problem occurred during compilation with the "
"command line below:")
print >> sys.stderr, ' '.join(cmd) print >> sys.stderr, ' '.join(cmd)
try: try:
...@@ -1457,8 +1516,8 @@ class GCC_compiler(object): ...@@ -1457,8 +1516,8 @@ class GCC_compiler(object):
# Print errors just below the command line. # Print errors just below the command line.
print compile_stderr print compile_stderr
# We replace '\n' by '. ' in the error message because when Python # We replace '\n' by '. ' in the error message because when Python
# prints the exception, having '\n' in the text makes it more difficult # prints the exception, having '\n' in the text makes it more
# to read. # difficult to read.
raise Exception('Compilation failed (return status=%s): %s' % raise Exception('Compilation failed (return status=%s): %s' %
(status, compile_stderr.replace('\n', '. '))) (status, compile_stderr.replace('\n', '. ')))
......
""" """
Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5) Helper functions to make gof backwards compatible
(tested on python 2.4 and 2.5)
""" """
import collections import collections
import sys import sys
if sys.version_info[:2] < (2,5): if sys.version_info[:2] < (2, 5):
def all(iterable): def all(iterable):
for element in iterable: for element in iterable:
...@@ -55,16 +57,19 @@ if sys.version_info[:2] < (2,5): ...@@ -55,16 +57,19 @@ if sys.version_info[:2] < (2,5):
raise TypeError('first argument must be callable') raise TypeError('first argument must be callable')
dict.__init__(self, *a, **kw) dict.__init__(self, *a, **kw)
self.default_factory = default_factory self.default_factory = default_factory
def __getitem__(self, key): def __getitem__(self, key):
try: try:
return dict.__getitem__(self, key) return dict.__getitem__(self, key)
except KeyError: except KeyError:
return self.__missing__(key) return self.__missing__(key)
def __missing__(self, key): def __missing__(self, key):
if self.default_factory is None: if self.default_factory is None:
raise KeyError(key) raise KeyError(key)
self[key] = value = self.default_factory() self[key] = value = self.default_factory()
return value return value
def __reduce__(self): def __reduce__(self):
if self.default_factory is None: if self.default_factory is None:
args = tuple() args = tuple()
...@@ -72,14 +77,18 @@ if sys.version_info[:2] < (2,5): ...@@ -72,14 +77,18 @@ if sys.version_info[:2] < (2,5):
args = self.default_factory, args = self.default_factory,
# consider replacing items() with iteritems() # consider replacing items() with iteritems()
return type(self), args, None, None, self.items() return type(self), args, None, None, self.items()
def copy(self): def copy(self):
return self.__copy__() return self.__copy__()
def __copy__(self): def __copy__(self):
return type(self)(self.default_factory, self) return type(self)(self.default_factory, self)
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
import copy import copy
return type(self)(self.default_factory, return type(self)(self.default_factory,
copy.deepcopy(self.items())) copy.deepcopy(self.items()))
def __repr__(self): def __repr__(self):
return 'defaultdict(%s, %s)' % (self.default_factory, return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self)) dict.__repr__(self))
...@@ -90,14 +99,15 @@ else: ...@@ -90,14 +99,15 @@ else:
import __builtin__ import __builtin__
all = __builtin__.all all = __builtin__.all
any = __builtin__.any any = __builtin__.any
import functools, collections import collections
import functools
partial = functools.partial partial = functools.partial
defaultdict = collections.defaultdict defaultdict = collections.defaultdict
deque = collections.deque deque = collections.deque
__all__ = ['all', 'any'] __all__ = ['all', 'any']
if sys.version_info[:2] < (2,6): if sys.version_info[:2] < (2, 6):
# Borrowed from Python docs # Borrowed from Python docs
def combinations(iterable, r): def combinations(iterable, r):
# combinations('ABCD', 2) --> AB AC AD BC BD CD # combinations('ABCD', 2) --> AB AC AD BC BD CD
...@@ -115,18 +125,17 @@ if sys.version_info[:2] < (2,6): ...@@ -115,18 +125,17 @@ if sys.version_info[:2] < (2,6):
else: else:
return return
indices[i] += 1 indices[i] += 1
for j in range(i+1, r): for j in range(i + 1, r):
indices[j] = indices[j-1] + 1 indices[j] = indices[j - 1] + 1
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
def product(*args, **kwds): def product(*args, **kwds):
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111 # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
pools = map(tuple, args) * kwds.get('repeat', 1) pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]] result = [[]]
for pool in pools: for pool in pools:
result = [x+[y] for x in result for y in pool] result = [x + [y] for x in result for y in pool]
for prod in result: for prod in result:
yield tuple(prod) yield tuple(prod)
......
...@@ -21,7 +21,6 @@ from theano.tensor import opt, get_constant_value ...@@ -21,7 +21,6 @@ from theano.tensor import opt, get_constant_value
from theano import gof from theano import gof
from theano.gof.python25 import maxsize from theano.gof.python25 import maxsize
from theano.compile import optdb from theano.compile import optdb
from theano import config
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
import scan_op import scan_op
...@@ -97,7 +96,6 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -97,7 +96,6 @@ def remove_constants_and_unused_inputs_scan(node):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
val = tensor.get_constant_value(node.inputs[idx + 1])
givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0] givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0]
except TypeError: except TypeError:
pass pass
...@@ -729,7 +727,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -729,7 +727,6 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:]) nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos]
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.basic.Subtensor(nw_slice)
# slice inputs # slice inputs
...@@ -748,7 +745,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -748,7 +745,6 @@ class ScanSaveMem(gof.Optimizer):
for pos, old_outs in old_outputs: for pos, old_outs in old_outputs:
if len(old_outs) > 0: if len(old_outs) > 0:
nw_pos = compress_map[pos] nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos]
for k, old in enumerate(old_outs): for k, old in enumerate(old_outs):
# Get the correct slice # Get the correct slice
cnf_slice, old_slices = slices[pos][k] cnf_slice, old_slices = slices[pos][k]
...@@ -1066,7 +1062,6 @@ def scan_merge_inouts(node): ...@@ -1066,7 +1062,6 @@ def scan_merge_inouts(node):
else: else:
a_inner_outs = a.inner_outputs a_inner_outs = a.inner_outputs
inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv) inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)
orig_outputs = a.outer_outputs
op = scan_op.Scan(inner_inputs, inner_outputs, info) op = scan_op.Scan(inner_inputs, inner_outputs, info)
outputs = op(*outer_inputs) outputs = op(*outer_inputs)
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import __builtin__
import sys import sys
from theano.configparser import config
import warnings import warnings
from itertools import izip from itertools import izip
...@@ -12,6 +10,7 @@ import numpy ...@@ -12,6 +10,7 @@ import numpy
#from copy import copy as python_copy #from copy import copy as python_copy
import theano import theano
from theano.configparser import config
from theano import gof from theano import gof
from theano.gof import Apply, Constant, Op, Type, Value, Variable from theano.gof import Apply, Constant, Op, Type, Value, Variable
...@@ -185,7 +184,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -185,7 +184,7 @@ def as_tensor_variable(x, name=None, ndim=None):
except TypeError: except TypeError:
try: try:
str_x = str(x) str_x = str(x)
except Exception, e: except Exception:
str_x = repr(x) str_x = repr(x)
raise TypeError("Cannot convert %s to TensorType" % str_x, type(x)) raise TypeError("Cannot convert %s to TensorType" % str_x, type(x))
...@@ -727,7 +726,6 @@ class TensorType(Type): ...@@ -727,7 +726,6 @@ class TensorType(Type):
self=self) self=self)
) )
def value_validity_msg(self, a): def value_validity_msg(self, a):
try: try:
self.filter(a, strict=True) self.filter(a, strict=True)
...@@ -735,33 +733,35 @@ class TensorType(Type): ...@@ -735,33 +733,35 @@ class TensorType(Type):
return str(e) return str(e)
return "value is valid" return "value is valid"
def dtype_specs(self): def dtype_specs(self):
"""Return a tuple (python type, c type, numpy typenum) that corresponds to """Return a tuple (python type, c type, numpy typenum) that corresponds
self.dtype. to self.dtype.
This function is used internally as part of C code generation. This function is used internally as part of C code generation.
""" """
#TODO: add more type correspondances for e.g. int32, int64, float32, #TODO: add more type correspondances for e.g. int32, int64, float32,
#complex64, etc. #complex64, etc.
try: try:
return {'float32': (float, 'npy_float32', 'NPY_FLOAT32'), return {
'float64': (float, 'npy_float64', 'NPY_FLOAT64'), 'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'), 'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'int8': (int, 'npy_int8', 'NPY_INT8'), 'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
'uint16': (int, 'npy_uint16', 'NPY_UINT16'), 'int8': (int, 'npy_int8', 'NPY_INT8'),
'int16': (int, 'npy_int16', 'NPY_INT16'), 'uint16': (int, 'npy_uint16', 'NPY_UINT16'),
'uint32': (int, 'npy_uint32', 'NPY_UINT32'), 'int16': (int, 'npy_int16', 'NPY_INT16'),
'int32': (int, 'npy_int32', 'NPY_INT32'), 'uint32': (int, 'npy_uint32', 'NPY_UINT32'),
'uint64': (int, 'npy_uint64', 'NPY_UINT64'), 'int32': (int, 'npy_int32', 'NPY_INT32'),
'int64': (int, 'npy_int64', 'NPY_INT64'), 'uint64': (int, 'npy_uint64', 'NPY_UINT64'),
'complex128': (complex, 'theano_complex128', 'NPY_COMPLEX128'), 'int64': (int, 'npy_int64', 'NPY_INT64'),
'complex64': (complex, 'theano_complex64', 'NPY_COMPLEX64')}[self.dtype] 'complex128': (complex, 'theano_complex128', 'NPY_COMPLEX128'),
'complex64': (complex, 'theano_complex64', 'NPY_COMPLEX64')
}[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s"
% (self.__class__.__name__, self.dtype))
def to_scalar_type(self): def to_scalar_type(self):
return scal.Scalar(dtype = self.dtype) return scal.Scalar(dtype=self.dtype)
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType""" """Compare True iff other is the same kind of TensorType"""
...@@ -769,10 +769,10 @@ class TensorType(Type): ...@@ -769,10 +769,10 @@ class TensorType(Type):
and other.broadcastable == self.broadcastable and other.broadcastable == self.broadcastable
@staticmethod @staticmethod
def may_share_memory(a,b): def may_share_memory(a, b):
# This is a method of TensorType, so both a and b should be ndarrays # This is a method of TensorType, so both a and b should be ndarrays
if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray): if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
return numpy.may_share_memory(a,b) return numpy.may_share_memory(a, b)
else: else:
return False return False
...@@ -784,9 +784,10 @@ class TensorType(Type): ...@@ -784,9 +784,10 @@ class TensorType(Type):
return False return False
if force_same_dtype and a.dtype != b.dtype: if force_same_dtype and a.dtype != b.dtype:
return False return False
a_eq_b = (a==b) a_eq_b = (a == b)
r = numpy.all(a_eq_b) r = numpy.all(a_eq_b)
if r: return True if r:
return True
# maybe the trouble is that there are NaNs # maybe the trouble is that there are NaNs
a_missing = numpy.isnan(a) a_missing = numpy.isnan(a)
if a_missing.any(): if a_missing.any():
...@@ -794,8 +795,9 @@ class TensorType(Type): ...@@ -794,8 +795,9 @@ class TensorType(Type):
return numpy.all(a_eq_b + (a_missing == b_missing)) return numpy.all(a_eq_b + (a_missing == b_missing))
else: else:
return False return False
@staticmethod @staticmethod
def values_eq_approx(a, b, allow_remove_inf = False, allow_remove_nan = False): def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False):
""" """
:param allow_remove_inf: If True, when there is an inf in a, :param allow_remove_inf: If True, when there is an inf in a,
we allow any value in b in that position. we allow any value in b in that position.
...@@ -810,10 +812,11 @@ class TensorType(Type): ...@@ -810,10 +812,11 @@ class TensorType(Type):
if a.dtype != b.dtype: if a.dtype != b.dtype:
return False return False
if 'int' in str(a.dtype): if 'int' in str(a.dtype):
return numpy.all(a==b) return numpy.all(a == b)
else: else:
#work around a numpy.allclose bug: http://projects.scipy.org/numpy/ticket/1672 # work around a numpy.allclose bug:
if a.ndim==0 and numpy.isinf(a): # http://projects.scipy.org/numpy/ticket/1672
if a.ndim == 0 and numpy.isinf(a):
a = a.reshape(1) a = a.reshape(1)
b = b.reshape(1) b = b.reshape(1)
...@@ -835,9 +838,10 @@ class TensorType(Type): ...@@ -835,9 +838,10 @@ class TensorType(Type):
if not (a_missing.any() or (allow_remove_inf and a_inf.any())): if not (a_missing.any() or (allow_remove_inf and a_inf.any())):
# There are no missing values in a, thus this is not the # There are no missing values in a, thus this is not the
# reason why numpy.allclose(a, b) returned False. # reason why numpy.allclose(a, b) returned False.
_logger.info('numpy allclose failed for abs_err %f and rel_err %f', _logger.info(
numpy.max(abs(a-b)), 'numpy allclose failed for abs_err %f and rel_err %f',
numpy.max(abs(a-b) / (abs(a) + abs(b)))) numpy.max(abs(a - b)),
numpy.max(abs(a - b) / (abs(a) + abs(b))))
return False return False
# The following line is what numpy.allclose bases its decision # The following line is what numpy.allclose bases its decision
# upon, according to its documentation. # upon, according to its documentation.
...@@ -853,11 +857,13 @@ class TensorType(Type): ...@@ -853,11 +857,13 @@ class TensorType(Type):
#cmp_elemwise is weird when we have inf and -inf. #cmp_elemwise is weird when we have inf and -inf.
#set it to False #set it to False
cmp_elemwise = numpy.where(both_inf&cmp_elemwise, cmp_elemwise = numpy.where(
a==b,cmp_elemwise) both_inf & cmp_elemwise,
a == b,
cmp_elemwise)
#check the sign of the inf #check the sign of the inf
both_inf = numpy.where(both_inf,a==b,both_inf) both_inf = numpy.where(both_inf, (a == b), both_inf)
if allow_remove_inf: if allow_remove_inf:
both_inf += a_inf both_inf += a_inf
...@@ -871,37 +877,38 @@ class TensorType(Type): ...@@ -871,37 +877,38 @@ class TensorType(Type):
@staticmethod @staticmethod
def values_eq_approx_remove_inf(a, b): def values_eq_approx_remove_inf(a, b):
return TensorType.values_eq_approx(a,b,True) return TensorType.values_eq_approx(a, b, True)
@staticmethod @staticmethod
def values_eq_approx_remove_nan(a, b): def values_eq_approx_remove_nan(a, b):
return TensorType.values_eq_approx(a,b,False,True) return TensorType.values_eq_approx(a, b, False, True)
@staticmethod @staticmethod
def values_eq_approx_remove_inf_nan(a, b): def values_eq_approx_remove_inf_nan(a, b):
return TensorType.values_eq_approx(a,b,True,True) return TensorType.values_eq_approx(a, b, True, True)
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions") ndim = property(lambda self: len(self.broadcastable),
doc="number of dimensions")
"""Number of dimensions """Number of dimensions
This read-only property is the preferred way to get the number of dimensions This read-only property is the preferred way to get the number of
of a `TensorType`. dimensions of a `TensorType`.
""" """
def make_variable(self, name = None): def make_variable(self, name=None):
"""Return a `TensorVariable` of this type """Return a `TensorVariable` of this type
:Parameters: :Parameters:
- `name`: str - `name`: str
A pretty name to identify this `Variable` when printing and debugging A pretty name to identify this `Variable` when printing and
debugging
""" """
return TensorVariable(self, name = name) return TensorVariable(self, name=name)
def __str__(self): def __str__(self):
if self.name: if self.name:
...@@ -932,14 +939,14 @@ class TensorType(Type): ...@@ -932,14 +939,14 @@ class TensorType(Type):
PyArrayObject* %(name)s; PyArrayObject* %(name)s;
int type_num_%(name)s; int type_num_%(name)s;
typedef %(dtype)s dtype_%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name = name, dtype = self.dtype_specs()[1]) """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
def c_init(self, name, sub): def c_init(self, name, sub):
"""Override `CLinkerOp.c_init` """ """Override `CLinkerOp.c_init` """
return """ return """
%(name)s = NULL; %(name)s = NULL;
type_num_%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
""" % dict(sub, name = name, type_num = self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub): def c_extract(self, name, sub):
"""Override `CLinkerOp.c_extract` """ """Override `CLinkerOp.c_extract` """
...@@ -969,7 +976,7 @@ class TensorType(Type): ...@@ -969,7 +976,7 @@ class TensorType(Type):
} }
%(name)s = (PyArrayObject*)(py_%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s); Py_XINCREF(%(name)s);
""" % dict(sub, name = name, type_num = self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
"""Override `CLinkerOp.c_cleanup` """ """Override `CLinkerOp.c_cleanup` """
...@@ -1018,12 +1025,14 @@ class TensorType(Type): ...@@ -1018,12 +1025,14 @@ class TensorType(Type):
# to have OutputGuard generate C code for this type. # to have OutputGuard generate C code for this type.
theano.compile.mode.register_OutputGuard_c_code(TensorType) theano.compile.mode.register_OutputGuard_c_code(TensorType)
# Easy constructors # Easy constructors
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
name = kwargs.pop('name',None) name = kwargs.pop('name', None)
return TensorType(*args, **kwargs).make_variable(name=name) return TensorType(*args, **kwargs).make_variable(name=name)
def _multi(*fns): def _multi(*fns):
def f2(f, *names): def f2(f, *names):
if names and isinstance(names[0], int): if names and isinstance(names[0], int):
...@@ -1051,7 +1060,9 @@ bscalar = TensorType('int8', ()) ...@@ -1051,7 +1060,9 @@ bscalar = TensorType('int8', ())
wscalar = TensorType('int16', ()) wscalar = TensorType('int16', ())
iscalar = TensorType('int32', ()) iscalar = TensorType('int32', ())
lscalar = TensorType('int64', ()) lscalar = TensorType('int64', ())
def scalar(name = None, dtype = None):
def scalar(name=None, dtype=None):
"""Return a symbolic scalar variable. """Return a symbolic scalar variable.
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
:param name: a name to attach to this variable :param name: a name to attach to this variable
...@@ -1060,7 +1071,9 @@ def scalar(name = None, dtype = None): ...@@ -1060,7 +1071,9 @@ def scalar(name = None, dtype = None):
dtype = config.floatX dtype = config.floatX
type = TensorType(dtype, ()) type = TensorType(dtype, ())
return type(name) return type(name)
scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscalar, iscalar, lscalar)
scalars, fscalars, dscalars, iscalars, lscalars = _multi(
scalar, fscalar, dscalar, iscalar, lscalar)
int_types = bscalar, wscalar, iscalar, lscalar int_types = bscalar, wscalar, iscalar, lscalar
float_types = fscalar, dscalar float_types = fscalar, dscalar
...@@ -1077,7 +1090,9 @@ bvector = TensorType('int8', (False,)) ...@@ -1077,7 +1090,9 @@ bvector = TensorType('int8', (False,))
wvector = TensorType('int16', (False,)) wvector = TensorType('int16', (False,))
ivector = TensorType('int32', (False, )) ivector = TensorType('int32', (False, ))
lvector = TensorType('int64', (False, )) lvector = TensorType('int64', (False, ))
def vector(name = None, dtype = None):
def vector(name=None, dtype=None):
"""Return a symbolic vector variable. """Return a symbolic vector variable.
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
:param name: a name to attach to this variable :param name: a name to attach to this variable
...@@ -1086,7 +1101,9 @@ def vector(name = None, dtype = None): ...@@ -1086,7 +1101,9 @@ def vector(name = None, dtype = None):
dtype = config.floatX dtype = config.floatX
type = TensorType(dtype, (False, )) type = TensorType(dtype, (False, ))
return type(name) return type(name)
vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvector, ivector, lvector)
vectors, fvectors, dvectors, ivectors, lvectors = _multi(
vector, fvector, dvector, ivector, lvector)
int_vector_types = bvector, wvector, ivector, lvector int_vector_types = bvector, wvector, ivector, lvector
float_vector_types = fvector, dvector float_vector_types = fvector, dvector
...@@ -1100,7 +1117,9 @@ bmatrix = TensorType('int8', (False, False)) ...@@ -1100,7 +1117,9 @@ bmatrix = TensorType('int8', (False, False))
wmatrix = TensorType('int16', (False, False)) wmatrix = TensorType('int16', (False, False))
imatrix = TensorType('int32', (False, False)) imatrix = TensorType('int32', (False, False))
lmatrix = TensorType('int64', (False, False)) lmatrix = TensorType('int64', (False, False))
def matrix(name = None, dtype = None):
def matrix(name=None, dtype=None):
"""Return a symbolic matrix variable. """Return a symbolic matrix variable.
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
:param name: a name to attach to this variable :param name: a name to attach to this variable
...@@ -1109,7 +1128,9 @@ def matrix(name = None, dtype = None): ...@@ -1109,7 +1128,9 @@ def matrix(name = None, dtype = None):
dtype = config.floatX dtype = config.floatX
type = TensorType(dtype, (False, False)) type = TensorType(dtype, (False, False))
return type(name) return type(name)
matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, dmatrix, imatrix, lmatrix)
matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(
matrix, fmatrix, dmatrix, imatrix, lmatrix)
int_matrix_types = bmatrix, wmatrix, imatrix, lmatrix int_matrix_types = bmatrix, wmatrix, imatrix, lmatrix
float_matrix_types = fmatrix, dmatrix float_matrix_types = fmatrix, dmatrix
...@@ -1123,7 +1144,9 @@ brow = TensorType('int8', (True, False)) ...@@ -1123,7 +1144,9 @@ brow = TensorType('int8', (True, False))
wrow = TensorType('int16', (True, False)) wrow = TensorType('int16', (True, False))
irow = TensorType('int32', (True, False)) irow = TensorType('int32', (True, False))
lrow = TensorType('int64', (True, False)) lrow = TensorType('int64', (True, False))
def row(name = None, dtype = None):
def row(name=None, dtype=None):
"""Return a symbolic row variable (ndim=2, broadcastable=[True,False]). """Return a symbolic row variable (ndim=2, broadcastable=[True,False]).
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
:param name: a name to attach to this variable :param name: a name to attach to this variable
...@@ -1142,7 +1165,9 @@ bcol = TensorType('int8', (False, True)) ...@@ -1142,7 +1165,9 @@ bcol = TensorType('int8', (False, True))
wcol = TensorType('int16', (False, True)) wcol = TensorType('int16', (False, True))
icol = TensorType('int32', (False, True)) icol = TensorType('int32', (False, True))
lcol = TensorType('int64', (False, True)) lcol = TensorType('int64', (False, True))
def col(name = None, dtype = None):
def col(name=None, dtype=None):
"""Return a symbolic column variable (ndim=2, broadcastable=[False,True]). """Return a symbolic column variable (ndim=2, broadcastable=[False,True]).
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
:param name: a name to attach to this variable :param name: a name to attach to this variable
...@@ -1153,14 +1178,16 @@ def col(name = None, dtype = None): ...@@ -1153,14 +1178,16 @@ def col(name = None, dtype = None):
return type(name) return type(name)
cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol) cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
ctensor3 = TensorType('complex64', (False,)*3) ctensor3 = TensorType('complex64', ((False,) * 3))
ztensor3 = TensorType('complex128', (False,)*3) ztensor3 = TensorType('complex128', ((False,) * 3))
ftensor3 = TensorType('float32', (False,)*3) ftensor3 = TensorType('float32', ((False,) * 3))
dtensor3 = TensorType('float64', (False,)*3) dtensor3 = TensorType('float64', ((False,) * 3))
btensor3 = TensorType('int8', (False,)*3) btensor3 = TensorType('int8', ((False,) * 3))
wtensor3 = TensorType('int16', (False,)*3) wtensor3 = TensorType('int16', ((False,) * 3))
itensor3 = TensorType('int32', (False,)*3) itensor3 = TensorType('int32', ((False,) * 3))
ltensor3 = TensorType('int64', (False,)*3) ltensor3 = TensorType('int64', ((False,) * 3))
def tensor3(name=None, dtype=None): def tensor3(name=None, dtype=None):
"""Return a symbolic 3-D variable. """Return a symbolic 3-D variable.
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
...@@ -1170,17 +1197,20 @@ def tensor3(name=None, dtype=None): ...@@ -1170,17 +1197,20 @@ def tensor3(name=None, dtype=None):
dtype = config.floatX dtype = config.floatX
type = TensorType(dtype, (False, False, False)) type = TensorType(dtype, (False, False, False))
return type(name) return type(name)
tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = _multi(tensor3, ftensor3, dtensor3,
itensor3, ltensor3) tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = _multi(
tensor3, ftensor3, dtensor3, itensor3, ltensor3)
ctensor4 = TensorType('complex64', (False,)*4)
ztensor4 = TensorType('complex128', (False,)*4) ctensor4 = TensorType('complex64', ((False,) * 4))
ftensor4 = TensorType('float32', (False,)*4) ztensor4 = TensorType('complex128', ((False,) * 4))
dtensor4 = TensorType('float64', (False,)*4) ftensor4 = TensorType('float32', ((False,) * 4))
btensor4 = TensorType('int8', (False,)*4) dtensor4 = TensorType('float64', ((False,) * 4))
wtensor4 = TensorType('int16', (False,)*4) btensor4 = TensorType('int8', ((False,) * 4))
itensor4 = TensorType('int32', (False,)*4) wtensor4 = TensorType('int16', ((False,) * 4))
ltensor4 = TensorType('int64', (False,)*4) itensor4 = TensorType('int32', ((False,) * 4))
ltensor4 = TensorType('int64', ((False,) * 4))
def tensor4(name=None, dtype=None): def tensor4(name=None, dtype=None):
"""Return a symbolic 4-D variable. """Return a symbolic 4-D variable.
:param dtype: numeric type (None means to use theano.config.floatX) :param dtype: numeric type (None means to use theano.config.floatX)
...@@ -1190,114 +1220,147 @@ def tensor4(name=None, dtype=None): ...@@ -1190,114 +1220,147 @@ def tensor4(name=None, dtype=None):
dtype = config.floatX dtype = config.floatX
type = TensorType(dtype, (False, False, False, False)) type = TensorType(dtype, (False, False, False, False))
return type(name) return type(name)
tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = _multi(tensor4, ftensor4, dtensor4, tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = _multi(
itensor4, ltensor4) tensor4, ftensor4, dtensor4, itensor4, ltensor4)
class _tensor_py_operators: class _tensor_py_operators:
#UNARY #UNARY
def __abs__(self): return abs_(self) def __abs__(self):
def __neg__(self): return neg(self) return abs_(self)
def __neg__(self):
return neg(self)
#CASTS #CASTS
#### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return an int. -JB 20081112 #### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return
#### an int. -JB 20081112
#def __int__(self): return convert_to_int32(self) #def __int__(self): return convert_to_int32(self)
#def __float__(self): return convert_to_float64(self) #def __float__(self): return convert_to_float64(self)
#def __complex__(self): return convert_to_complex128(self) #def __complex__(self): return convert_to_complex128(self)
#COMPARISONS #COMPARISONS
_is_nonzero = True _is_nonzero = True
def __lt__(self,other):
def __lt__(self, other):
rval = lt(self, other) rval = lt(self, other)
rval._is_nonzero=False rval._is_nonzero = False
return rval return rval
def __le__(self,other):
rval = le(self, other) def __le__(self, other):
rval._is_nonzero=False rval = le(self, other)
rval._is_nonzero = False
return rval return rval
def __gt__(self,other):
def __gt__(self, other):
rval = gt(self, other) rval = gt(self, other)
rval._is_nonzero=False rval._is_nonzero = False
return rval return rval
def __ge__(self,other):
def __ge__(self, other):
rval = ge(self, other) rval = ge(self, other)
rval._is_nonzero=False rval._is_nonzero = False
return rval return rval
def __nonzero__(self): def __nonzero__(self):
# This is meant to prohibit stuff like a < b < c, which is internally implemented as # This is meant to prohibit stuff like a < b < c, which is internally
# (a < b) and (b < c). The trouble with this is the side-effect that checking for a # implemented as (a < b) and (b < c). The trouble with this is the
# non-NULL a by typing "if a: ..." uses the same __nonzero__ method. We want these # side-effect that checking for a non-NULL a by typing "if a: ..."
# both to work, but it seems impossible. Currently, all vars evaluate to nonzero # uses the same __nonzero__ method. We want these both to work, but
# except the return values of comparison operators, which raise this exception. If you # it seems impossible. Currently, all vars evaluate to nonzero except
# can think of a better solution, go for it! # the return values of comparison operators, which raise this
# exception. If you can think of a better solution, go for it!
if self._is_nonzero: if self._is_nonzero:
return True return True
else: else:
raise TypeError("Variable does not support boolean operations.") raise TypeError("Variable does not support boolean operations.")
#BITWISE #BITWISE
def __invert__(self): return invert(self) def __invert__(self):
def __and__(self,other): return and_(self, other) return invert(self)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other) def __and__(self, other):
def __rand__(self,other): return and_(other,self) return and_(self, other)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self) def __or__(self, other):
# def __iand__(self, other): return _and_inplace(self, other) return or_(self, other)
# def __ior__(self, other): return _or_inplace(self, other)
# def __ixor__(self, other): return _xor_inplace(self, other) def __xor__(self, other):
return xor(self, other)
def __rand__(self, other):
return and_(other, self)
def __ror__(self, other):
return or_(other, self)
def __rxor__(self, other):
return xor(other, self)
#def __iand__(self, other):
# return _and_inplace(self, other)
#
#def __ior__(self, other):
# return _or_inplace(self, other)
#
#def __ixor__(self, other):
# return _xor_inplace(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): def __add__(self, other):
try: try:
return add(self,other) return add(self, other)
# We should catch the minimum number of exception here. # We should catch the minimum number of exception here.
# Otherwise this will convert error when Theano flags # Otherwise this will convert error when Theano flags
# compute_test_value is used # compute_test_value is used
# Evidently, we need to catch NotImplementedError # Evidently, we need to catch NotImplementedError
# But we also need to catch TypeError # But we also need to catch TypeError
# Oterwise TensorVariable * SparseVariable won't work! # Oterwise TensorVariable * SparseVariable won't work!
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
# We must return NotImplemented and not an # We must return NotImplemented and not an
# NotImplementedError or raise an NotImplementedError. # NotImplementedError or raise an NotImplementedError.
# That way python will give a good error message like this # That way python will give a good error message like this
# `TypeError: unsupported operand type(s) for +: # `TypeError: unsupported operand type(s) for +:
# 'TensorVariable' and 'TensorVariable'` # 'TensorVariable' and 'TensorVariable'`
return NotImplemented return NotImplemented
def __sub__(self,other):
def __sub__(self, other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return sub(self,other) return sub(self, other)
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __mul__(self,other):
def __mul__(self, other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return mul(self,other) return mul(self, other)
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __div__(self,other):
def __div__(self, other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return div_proxy(self,other) return div_proxy(self, other)
except IntegerDivisionError: except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide # This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden). # two integer arrays (currently forbidden).
raise raise
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __pow__(self,other):
def __pow__(self, other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
return pow(self,other) return pow(self, other)
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __mod__(self,other):
def __mod__(self, other):
# See explanation in __add__ for the error catched # See explanation in __add__ for the error catched
# adn the return value in that case # adn the return value in that case
try: try:
...@@ -1306,29 +1369,56 @@ class _tensor_py_operators: ...@@ -1306,29 +1369,56 @@ class _tensor_py_operators:
# This is to raise the exception that occurs when trying to compute # This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number. # x % y with either x or y a complex number.
raise raise
except (NotImplementedError, TypeError), e: except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __truediv__(self,other): return true_div(self, other) def __truediv__(self, other):
def __floordiv__(self,other): return floor_div(self, other) return true_div(self, other)
def __rtruediv__(self,other): return true_div(other, self)
def __rfloordiv__(self,other): return floor_div(other, self) def __floordiv__(self, other):
return floor_div(self, other)
# ##### DON"T USE THESE BECAUSE INPLACE OPS SHOULD BE INSERTED BY OPTIMIZATION ONLY
# #ARITHMETIC - INPLACE def __rtruediv__(self, other):
# def __iadd__(self,other): return _add_inplace(self,other) return true_div(other, self)
# def __isub__(self,other): return _sub_inplace(self,other)
# def __imul__(self,other): return _mul_inplace(self,other) def __rfloordiv__(self, other):
# def __idiv__(self,other): return _div_inplace(self,other) return floor_div(other, self)
# def __ipow__(self,other): return _pow_inplace(self,other)
##### DO NOT USE THESE BECAUSE INPLACE OPS SHOULD BE INSERTED
#ARITHMETIC - RIGHT-OPERAND ##### BY OPTIMIZATIONS ONLY
def __radd__(self,other): return add(other,self) ## ARITHMETIC - INPLACE
def __rsub__(self,other): return sub(other,self) #def __iadd__(self, other):
def __rmul__(self,other): return mul(other,self) # return _add_inplace(self, other)
def __rdiv__(self,other): return div_proxy(other,self) #def __isub__(self, other):
def __rmod__(self,other): return mod(other,self) # return _sub_inplace(self, other)
def __rpow__(self,other): return pow(other,self) #
#def __imul__(self, other):
# return _mul_inplace(self, other)
#
#def __idiv__(self, other):
# return _div_inplace(self, other)
#
#def __ipow__(self, other):
# return _pow_inplace(self, other)
# ARITHMETIC - RIGHT-OPERAND
def __radd__(self, other):
return add(other, self)
def __rsub__(self, other):
return sub(other, self)
def __rmul__(self, other):
return mul(other, self)
def __rdiv__(self, other):
return div_proxy(other, self)
def __rmod__(self, other):
return mod(other, self)
def __rpow__(self, other):
return pow(other, self)
#TRANSPOSE #TRANSPOSE
T = property(lambda self: transpose(self)) T = property(lambda self: transpose(self))
...@@ -1360,43 +1450,51 @@ class _tensor_py_operators: ...@@ -1360,43 +1450,51 @@ class _tensor_py_operators:
size = property(lambda self: prod(self.shape)) size = property(lambda self: prod(self.shape))
# We can't implement __len__ to provide a better error message. # We can't implement __len__ to provide a better error message.
def any(self, axis = None): def any(self, axis=None):
return elemwise.Any(axis)(self) return elemwise.Any(axis)(self)
def all(self, axis = None): def all(self, axis=None):
return elemwise.All(axis)(self) return elemwise.All(axis)(self)
# Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls # Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls
# __len__ before calling __getitem__. It also does not catch the raised # __len__ before calling __getitem__. It also does not catch the raised
# Exception! # Exception!
# def __len__(self): # def __len__(self):
# # We can't implement __len__ as Python requests that this # # We can't implement __len__ as Python requests that this
# # function returns an integer >=0 # # function returns an integer >=0
# raise Exception("Theano Variables can't work with len(Theano " # raise Exception("Theano Variables can't work with len(Theano "
# "Variable) due to Python restriction. You can use " # "Variable) due to Python restriction. You can use "
# "TheanoVariable.shape[0] instead.") # "TheanoVariable.shape[0] instead.")
def reshape(self, shape, ndim=None): def reshape(self, shape, ndim=None):
"""Return a reshaped view/copy of this variable. """Return a reshaped view/copy of this variable.
:param shape: something that can be converted to a symbolic vector of integers :param shape: something that can be converted to a symbolic vector of
integers
:param ndim: the length of the shape. Passing None here means for theano to try and :param ndim: the length of the shape. Passing None here means for
guess the length of `shape`. theano to try and guess the length of `shape`.
""" """
return reshape(self, shape, ndim=ndim) return reshape(self, shape, ndim=ndim)
def dimshuffle(self, *pattern): def dimshuffle(self, *pattern):
"""Reorder the dimensions of this variable, optionally inserting broadcasted dimensions. """
Reorder the dimensions of this variable, optionally inserting
broadcasted dimensions.
:param pattern: list/tuple of int mixed with 'x' for broadcastable dimensions :param pattern: list/tuple of int mixed with 'x' for broadcastable
dimensions
For example, to create a 3D view of a [2D] matrix, call ``dimshuffle([0,'x',1])``. This For example, to create a 3D view of a [2D] matrix, call
will create a 3D view such that the middle dimension is an implicit broadcasted ``dimshuffle([0,'x',1])``. This will create a 3D view such that the
dimension. To do the same thing on the transpose of that matrix, call ``dimshuffle([1, middle dimension is an implicit broadcasted dimension. To do the same
'x', 0])``. thing on the transpose of that matrix, call
``dimshuffle([1, 'x', 0])``.
This function supports the pattern passed as a tuple, or as a variable-length argument (e.g. ``a.dimshuffle(pattern)`` is equivalent to ``a.dimshuffle(*pattern)`` where ``pattern`` is a list/tuple of ints mixed with 'x' characters). This function supports the pattern passed as a tuple, or as a
variable-length argument (e.g. ``a.dimshuffle(pattern)`` is equivalent
to ``a.dimshuffle(*pattern)`` where ``pattern`` is a list/tuple of ints
mixed with 'x' characters).
For more information, see `DimShuffle`. For more information, see `DimShuffle`.
""" """
...@@ -1447,7 +1545,8 @@ class _tensor_py_operators: ...@@ -1447,7 +1545,8 @@ class _tensor_py_operators:
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
return Subtensor(args)(self, *Subtensor.collapse(args, lambda entry: isinstance(entry, Variable))) return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))
#COPYING #COPYING
def copy(self): def copy(self):
...@@ -1457,11 +1556,11 @@ class _tensor_py_operators: ...@@ -1457,11 +1556,11 @@ class _tensor_py_operators:
try: try:
for i in xrange(get_vector_length(self)): for i in xrange(get_vector_length(self)):
yield self[i] yield self[i]
except TypeError, e: except TypeError:
# This prevents accidental iteration via builtin.sum(self) # This prevents accidental iteration via builtin.sum(self)
raise TypeError('TensorType does not support iteration. ' raise TypeError(('TensorType does not support iteration. '
'Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)') 'Maybe you are using builtin.sum instead of '
'theano.tensor.sum? (Maybe .max?)'))
# CONVENIENT ACCESS TO TYPE PROPERTIES # CONVENIENT ACCESS TO TYPE PROPERTIES
ndim = property(lambda self: self.type.ndim) ndim = property(lambda self: self.type.ndim)
...@@ -1471,7 +1570,6 @@ class _tensor_py_operators: ...@@ -1471,7 +1570,6 @@ class _tensor_py_operators:
"""The broadcastable signature of this tensor. """The broadcastable signature of this tensor.
See :doc:`broadcasting` for details. See :doc:`broadcasting` for details.
""" """
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
...@@ -1493,12 +1591,12 @@ class _tensor_py_operators: ...@@ -1493,12 +1591,12 @@ class _tensor_py_operators:
return prod(self, axis=axis, dtype=dtype) return prod(self, axis=axis, dtype=dtype)
def norm(self, L, axis=None): def norm(self, L, axis=None):
if L==0: if L == 0:
raise NotImplementedError() raise NotImplementedError()
if numpy.isinf(L): if numpy.isinf(L):
raise NotImplementedError() raise NotImplementedError()
#optimizations will/should catch cases like L=1, L=2 #optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L) return pow(pow(abs_(self), L).sum(axis=axis), 1.0 / L)
def mean(self, axis=None, dtype=None): def mean(self, axis=None, dtype=None):
"""See `theano.tensor.mean`""" """See `theano.tensor.mean`"""
...@@ -1521,6 +1619,7 @@ class _tensor_py_operators: ...@@ -1521,6 +1619,7 @@ class _tensor_py_operators:
def get_constant_value(self): def get_constant_value(self):
return get_constant_value(self) return get_constant_value(self)
def zeros_like(model): def zeros_like(model):
return zeros_like(model) return zeros_like(model)
...@@ -1540,17 +1639,19 @@ class TensorConstantSignature(tuple): ...@@ -1540,17 +1639,19 @@ class TensorConstantSignature(tuple):
if type(self) != type(other): if type(self) != type(other):
return False return False
try: try:
(t0, d0), (t1,d1) = self, other (t0, d0), (t1, d1) = self, other
except Exception, e: except Exception:
return False return False
#N.B. compare shape to ensure no broadcasting in == #N.B. compare shape to ensure no broadcasting in ==
if t0 != t1 or d0.shape != d1.shape: if t0 != t1 or d0.shape != d1.shape:
return False return False
no_nan = self.no_nan # Ensure has_nan is computed.
self.no_nan # Ensure has_nan is computed.
# Note that in the comparisons below, the elementwise comparisons # Note that in the comparisons below, the elementwise comparisons
# come last because they are the most expensive checks. # come last because they are the most expensive checks.
if self.has_nan: if self.has_nan:
other_no_nan = other.no_nan other.no_nan # Ensure has_nan is computed.
return (other.has_nan and return (other.has_nan and
self.sum == other.sum and self.sum == other.sum and
(self.no_nan.mask == other.no_nan.mask).all() and (self.no_nan.mask == other.no_nan.mask).all() and
...@@ -1620,7 +1721,7 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -1620,7 +1721,7 @@ class TensorConstant(_tensor_py_operators, Constant):
To create a TensorConstant, use the `constant` function in this module. To create a TensorConstant, use the `constant` function in this module.
""" """
def __init__(self, type, data, name = None): def __init__(self, type, data, name=None):
Constant.__init__(self, type, data, name) Constant.__init__(self, type, data, name)
if (isinstance(data, numpy.ndarray) and if (isinstance(data, numpy.ndarray) and
data.ndim > 0 and data.ndim > 0 and
...@@ -1631,12 +1732,12 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -1631,12 +1732,12 @@ class TensorConstant(_tensor_py_operators, Constant):
def __str__(self): def __str__(self):
if self.tag.unique_value is not None: if self.tag.unique_value is not None:
name = "%s of %s"%(str(self.data.shape), name = "%s of %s" % (str(self.data.shape),
str(self.tag.unique_value)) str(self.tag.unique_value))
else: else:
name = "%s"%self.data name = "%s" % self.data
if len(name) > 20: if len(name) > 20:
name = name[:10]+".."+name[-10:] name = name[:10] + ".." + name[-10:]
return "TensorConstant{%s}" % name return "TensorConstant{%s}" % name
...@@ -1677,15 +1778,19 @@ def _redefine(real_symbol_value, module='tensor'): ...@@ -1677,15 +1778,19 @@ def _redefine(real_symbol_value, module='tensor'):
This is useful to trick epydoc into doing what we want. It's a hack. This is useful to trick epydoc into doing what we want. It's a hack.
""" """
real_symbol_value.__module__ = 'tensor.basic' real_symbol_value.__module__ = 'tensor.basic'
def decorator(f): def decorator(f):
return real_symbol_value return real_symbol_value
return decorator return decorator
def _redefine_asRoutine(real_symbol_value): def _redefine_asRoutine(real_symbol_value):
real_symbol_value.__epydoc_asRoutine = True real_symbol_value.__epydoc_asRoutine = True
def decorator(f): def decorator(f):
return real_symbol_value return real_symbol_value
return decorator return decorator
...@@ -1707,17 +1812,18 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout): ...@@ -1707,17 +1812,18 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):
msg = "inplace" msg = "inplace"
else: else:
msg = "no_inplace" msg = "no_inplace"
n="Elemwise{%s,%s}"%(symbolname,msg)
n = "Elemwise{%s,%s}" % (symbolname, msg)
if inplace: if inplace:
scalar_op = getattr(scal, symbolname[:-len('_inplace')]) scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n, rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=n,
nfunc_spec = nfunc and (nfunc, nin, nout)) nfunc_spec=(nfunc and (nfunc, nin, nout)))
else: else:
scalar_op = getattr(scal, symbolname) scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=n, rval = elemwise.Elemwise(scalar_op, name=n,
nfunc_spec = nfunc and (nfunc, nin, nout)) nfunc_spec=(nfunc and (nfunc, nin, nout)))
if getattr(symbol, '__doc__', False): if getattr(symbol, '__doc__', False):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__ rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
...@@ -1744,35 +1850,44 @@ class TensorFromScalar(Op): ...@@ -1744,35 +1850,44 @@ class TensorFromScalar(Op):
assert isinstance(s.type, scal.Scalar) assert isinstance(s.type, scal.Scalar)
return Apply(self, return Apply(self,
[s], [s],
[tensor(dtype = s.type.dtype, [tensor(dtype=s.type.dtype,
broadcastable = ())]) broadcastable=())])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
s, = inp s, = inp
out, = out_ out, = out_
out[0] = numpy.asarray(s) out[0] = numpy.asarray(s)
def grad(self, inp, grads): def grad(self, inp, grads):
s, = inp s, = inp
dt, = grads dt, = grads
return [scalar_from_tensor(dt)] return [scalar_from_tensor(dt)]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
tensor_from_scalar = TensorFromScalar() tensor_from_scalar = TensorFromScalar()
class ScalarFromTensor(Op): class ScalarFromTensor(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, t): def make_node(self, t):
assert isinstance(t.type, TensorType) assert isinstance(t.type, TensorType)
assert t.type.broadcastable == () assert t.type.broadcastable == ()
return Apply(self, return Apply(self,
[t], [t],
[scal.Scalar(dtype = t.type.dtype).make_variable()]) [scal.Scalar(dtype=t.type.dtype).make_variable()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
s, = inp s, = inp
out, = out_ out, = out_
out[0] = s.flatten()[0] out[0] = s.flatten()[0]
def grad(self, inp, grads): def grad(self, inp, grads):
s, = inp s, = inp
dt, = grads dt, = grads
...@@ -1785,66 +1900,81 @@ class ScalarFromTensor(Op): ...@@ -1785,66 +1900,81 @@ class ScalarFromTensor(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
x, = inputs x, = inputs
z, = outputs z, = outputs
fail = sub['fail'] fail = sub['fail']
return """ return """
%(z)s = ((dtype_%(x)s*)(%(x)s->data))[0]; %(z)s = ((dtype_%(x)s*)(%(x)s->data))[0];
"""%locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924 #to be removed as we get the epydoc routine-documenting thing going
#-JB 20080924
def _conversion(real_value, name): def _conversion(real_value, name):
__oplist_tag(real_value, 'casting') __oplist_tag(real_value, 'casting')
real_value.__module__='tensor.basic' real_value.__module__ = 'tensor.basic'
pprint.assign(real_value, printing.FunctionPrinter(name)) pprint.assign(real_value, printing.FunctionPrinter(name))
return real_value return real_value
# # These _conver_to_<type> functions have leading underscores to indicate that
# These _conver_to_<type> functions have leading underscores to indicate that they should not # they should not be called directly. They do not perform sanity checks about
# be called directly. They do not perform sanity checks about what types you are casting to # what types you are casting to what. That logic is implemented by the
# what. That logic is implemented by the `cast()` function below. # `cast()` function below.
#
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), 'int8') _convert_to_int8 = _conversion(
elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), 'int16') _convert_to_int16 = _conversion(
elemwise.Elemwise(scal.convert_to_int16), 'int16')
"""Cast to 16-bit integer""" """Cast to 16-bit integer"""
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), 'int32') _convert_to_int32 = _conversion(
elemwise.Elemwise(scal.convert_to_int32), 'int32')
"""Cast to 32-bit integer""" """Cast to 32-bit integer"""
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), 'int64') _convert_to_int64 = _conversion(
elemwise.Elemwise(scal.convert_to_int64), 'int64')
"""Cast to 64-bit integer""" """Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(elemwise.Elemwise(scal.convert_to_uint8), 'uint8') _convert_to_uint8 = _conversion(
elemwise.Elemwise(scal.convert_to_uint8), 'uint8')
"""Cast to unsigned 8-bit integer""" """Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(elemwise.Elemwise(scal.convert_to_uint16), 'uint16') _convert_to_uint16 = _conversion(
elemwise.Elemwise(scal.convert_to_uint16), 'uint16')
"""Cast to unsigned 16-bit integer""" """Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(elemwise.Elemwise(scal.convert_to_uint32), 'uint32') _convert_to_uint32 = _conversion(
elemwise.Elemwise(scal.convert_to_uint32), 'uint32')
"""Cast to unsigned 32-bit integer""" """Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(elemwise.Elemwise(scal.convert_to_uint64), 'uint64') _convert_to_uint64 = _conversion(
elemwise.Elemwise(scal.convert_to_uint64), 'uint64')
"""Cast to unsigned 64-bit integer""" """Cast to unsigned 64-bit integer"""
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), 'float32') _convert_to_float32 = _conversion(
elemwise.Elemwise(scal.convert_to_float32), 'float32')
"""Cast to single-precision floating point""" """Cast to single-precision floating point"""
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), 'float64') _convert_to_float64 = _conversion(
elemwise.Elemwise(scal.convert_to_float64), 'float64')
"""Cast to double-precision floating point""" """Cast to double-precision floating point"""
_convert_to_complex64 = _conversion(elemwise.Elemwise(scal.convert_to_complex64), 'complex64') _convert_to_complex64 = _conversion(
elemwise.Elemwise(scal.convert_to_complex64), 'complex64')
"""Cast to single-precision complex""" """Cast to single-precision complex"""
_convert_to_complex128 = _conversion(elemwise.Elemwise(scal.convert_to_complex128), 'complex128') _convert_to_complex128 = _conversion(
elemwise.Elemwise(scal.convert_to_complex128), 'complex128')
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
...@@ -1860,20 +1990,24 @@ _cast_mapping = { ...@@ -1860,20 +1990,24 @@ _cast_mapping = {
'float64': _convert_to_float64, 'float64': _convert_to_float64,
'complex64': _convert_to_complex64, 'complex64': _convert_to_complex64,
'complex128': _convert_to_complex128} 'complex128': _convert_to_complex128}
@constructor @constructor
def cast(x, dtype): def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`.""" """Symbolically cast `x` to a Tensor of type `dtype`."""
if dtype=='floatX': dtype = config.floatX if dtype == 'floatX':
dtype = config.floatX
_x = as_tensor_variable(x) _x = as_tensor_variable(x)
if _x.type.dtype == dtype: if _x.type.dtype == dtype:
return _x return _x
if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'): if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()') raise TypeError((
'Casting from complex to real is ambiguous: consider real(), '
'imag(), angle() or abs()'))
return _cast_mapping[dtype](x) return _cast_mapping[dtype](x)
########################## ##########################
# Unary Operations # Unary Operations
########################## ##########################
...@@ -1886,10 +2020,13 @@ class Shape(Op): ...@@ -1886,10 +2020,13 @@ class Shape(Op):
""" """
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
#Must work for all type that have a shape attribute. #Must work for all type that have a shape attribute.
#This will fail at execution time. #This will fail at execution time.
...@@ -1899,21 +2036,29 @@ class Shape(Op): ...@@ -1899,21 +2036,29 @@ class Shape(Op):
#the type to TensorVariable to have the optimization working #the type to TensorVariable to have the optimization working
#correctly. #correctly.
return Apply(self, [x], [lvector()]) return Apply(self, [x], [lvector()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, = inp x, = inp
out, = out_ out, = out_
out[0] = theano._asarray(x.shape, dtype = 'int64') out[0] = theano._asarray(x.shape, dtype='int64')
def grad(self, inp, grads): def grad(self, inp, grads):
return [None] return [None]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
@constructor @constructor
def old_shape(a): def old_shape(a):
"""Return the shape tuple of a TensorType Variable, it may be either symbolic or nonsymbolic. """
Return the shape tuple of a TensorType Variable.
It may be either symbolic or nonsymbolic.
If the shape of the expression is not known at graph-construction time, then a symbolic If the shape of the expression is not known at graph-construction time,
lvector will be returned, corresponding to the actual shape at graph-execution time. then a symbolic lvector will be returned, corresponding to the actual
shape at graph-execution time.
""" """
va = as_tensor_variable(a) va = as_tensor_variable(a)
#print 'HERE', va, va.type #print 'HERE', va, va.type
...@@ -1926,7 +2071,7 @@ def old_shape(a): ...@@ -1926,7 +2071,7 @@ def old_shape(a):
return va.type.shape return va.type.shape
shape = Shape() shape = Shape()
_shape = shape #was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
pprint.assign(_shape, printing.MemberPrinter('shape')) pprint.assign(_shape, printing.MemberPrinter('shape'))
...@@ -1974,7 +2119,7 @@ class SpecifyShape(Op): ...@@ -1974,7 +2119,7 @@ class SpecifyShape(Op):
s = get_constant_value(node.inputs[1][dim]) s = get_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s) s = as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except TypeError, e: except TypeError:
new_shape.append(node.inputs[1][dim]) new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
...@@ -2185,16 +2330,25 @@ def argmin(x, axis=None): ...@@ -2185,16 +2330,25 @@ def argmin(x, axis=None):
@constructor @constructor
def smallest(*args): def smallest(*args):
"""Return the [elementwise] smallest of a variable number of arguments (like python's min).""" """
Return the [elementwise] smallest of a variable number of arguments.
Like python's min.
"""
if len(args) == 2: if len(args) == 2:
a, b = args a, b = args
return switch(a < b, a, b) return switch(a < b, a, b)
else: else:
return min(stack(*args), axis=0) return min(stack(*args), axis=0)
@constructor @constructor
def largest(*args): def largest(*args):
"""Return the [elementwise] largest of a variable number of arguments (like python's max).""" """
Return the [elementwise] largest of a variable number of arguments.
Like python's max.
"""
if len(args) == 2: if len(args) == 2:
a, b = args a, b = args
return switch(a > b, a, b) return switch(a > b, a, b)
...@@ -2210,30 +2364,37 @@ def largest(*args): ...@@ -2210,30 +2364,37 @@ def largest(*args):
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise_with_nfunc('greater', 2, 1) @_scal_elemwise_with_nfunc('greater', 2, 1)
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise_with_nfunc('less_equal', 2, 1) @_scal_elemwise_with_nfunc('less_equal', 2, 1)
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise_with_nfunc('greater_equal', 2, 1) @_scal_elemwise_with_nfunc('greater_equal', 2, 1)
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise_with_nfunc('equal', 2, 1) @_scal_elemwise_with_nfunc('equal', 2, 1)
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise_with_nfunc('not_equal', 2, 1) @_scal_elemwise_with_nfunc('not_equal', 2, 1)
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1) @_scal_elemwise_with_nfunc('isnan', 1, 1)
def isnan(a): def isnan(a):
"""isnan(a)""" """isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1) @_scal_elemwise_with_nfunc('isinf', 1, 1)
def isinf(a): def isinf(a):
"""isinf(a)""" """isinf(a)"""
...@@ -2253,24 +2414,27 @@ def switch(cond, ift, iff): ...@@ -2253,24 +2414,27 @@ def switch(cond, ift, iff):
########################## ##########################
@_scal_elemwise_with_nfunc('bitwise_and', 2, 1) @_scal_elemwise_with_nfunc('bitwise_and', 2, 1)
def and_(a,b): def and_(a, b):
"""bitwise a & b""" """bitwise a & b"""
bitwise_and = and_ # numpy name for it bitwise_and = and_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_or', 2, 1) @_scal_elemwise_with_nfunc('bitwise_or', 2, 1)
def or_(a,b): def or_(a, b):
"""bitwise a | b""" """bitwise a | b"""
bitwise_or = or_ # numpy name for it bitwise_or = or_ # numpy name for it
@_scal_elemwise_with_nfunc('bitwise_xor', 2, 1) @_scal_elemwise_with_nfunc('bitwise_xor', 2, 1)
def xor(a,b): def xor(a, b):
"""bitwise a ^ b""" """bitwise a ^ b"""
bitwise_xor = xor # numpy name for it bitwise_xor = xor # numpy name for it
@_scal_elemwise_with_nfunc('invert', 1, 1) @_scal_elemwise_with_nfunc('invert', 1, 1)
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
bitwise_not = invert # numpy alias for it bitwise_not = invert # numpy alias for it
########################## ##########################
...@@ -2288,50 +2452,64 @@ def abs_(a): ...@@ -2288,50 +2452,64 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000))) pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise_with_nfunc('exp', 1, 1) @_scal_elemwise_with_nfunc('exp', 1, 1)
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise_with_nfunc('negative', 1, 1) @_scal_elemwise_with_nfunc('negative', 1, 1)
def neg(a): def neg(a):
"""-a""" """-a"""
@_scal_elemwise # numpy.reciprocal does integer division on integer inputs (which is not very interesting)
# numpy.reciprocal does integer division on integer inputs
# (which is not very interesting)
@_scal_elemwise
def inv(a): def inv(a):
"""1.0/a""" """1.0/a"""
@_scal_elemwise_with_nfunc('log', 1, 1) @_scal_elemwise_with_nfunc('log', 1, 1)
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise_with_nfunc('log2', 1, 1) @_scal_elemwise_with_nfunc('log2', 1, 1)
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise_with_nfunc('log10', 1, 1) @_scal_elemwise_with_nfunc('log10', 1, 1)
def log10(a): def log10(a):
"""base 10 logarithm of a""" """base 10 logarithm of a"""
@_scal_elemwise_with_nfunc('log1p', 1, 1) @_scal_elemwise_with_nfunc('log1p', 1, 1)
def log1p(a): def log1p(a):
"""log(1+a)""" """log(1+a)"""
@_scal_elemwise_with_nfunc('sign', 1, 1) @_scal_elemwise_with_nfunc('sign', 1, 1)
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise_with_nfunc('ceil', 1, 1) @_scal_elemwise_with_nfunc('ceil', 1, 1)
def ceil(a): def ceil(a):
"""ceiling of a""" """ceiling of a"""
@_scal_elemwise_with_nfunc('floor', 1, 1) @_scal_elemwise_with_nfunc('floor', 1, 1)
def floor(a): def floor(a):
"""floor of a""" """floor of a"""
@constructor @constructor
def iround(a, mode="half_away_from_zero"): def iround(a, mode="half_away_from_zero"):
"""cast(round(a,mode),'int64')""" """cast(round(a,mode),'int64')"""
return cast(round(a,mode),'int64') return cast(round(a, mode), 'int64')
@constructor @constructor
def round(a, mode="half_away_from_zero"): def round(a, mode="half_away_from_zero"):
...@@ -2341,80 +2519,99 @@ def round(a, mode="half_away_from_zero"): ...@@ -2341,80 +2519,99 @@ def round(a, mode="half_away_from_zero"):
elif mode == "half_to_even": elif mode == "half_to_even":
return round_half_to_even(a) return round_half_to_even(a)
else: else:
raise Exception("round mode %s is not implemented."%mode) raise Exception("round mode %s is not implemented." % mode)
@_scal_elemwise_with_nfunc('around', 1, -1) @_scal_elemwise_with_nfunc('around', 1, -1)
def round_half_to_even(a): def round_half_to_even(a):
"""round_half_to_even(a)""" """round_half_to_even(a)"""
@_scal_elemwise @_scal_elemwise
def round_half_away_from_zero(a): def round_half_away_from_zero(a):
"""round_half_away_from_zero(a)""" """round_half_away_from_zero(a)"""
@_scal_elemwise_with_nfunc('square', 1, 1) @_scal_elemwise_with_nfunc('square', 1, 1)
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
@_scal_elemwise_with_nfunc('sqrt', 1, 1) @_scal_elemwise_with_nfunc('sqrt', 1, 1)
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise_with_nfunc('cos', 1, 1) @_scal_elemwise_with_nfunc('cos', 1, 1)
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise_with_nfunc('arccos',1,1)
@_scal_elemwise_with_nfunc('arccos', 1, 1)
def arccos(a): def arccos(a):
"""arccosine of a""" """arccosine of a"""
@_scal_elemwise_with_nfunc('sin', 1, 1) @_scal_elemwise_with_nfunc('sin', 1, 1)
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise_with_nfunc('tan', 1, 1) @_scal_elemwise_with_nfunc('tan', 1, 1)
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise_with_nfunc('cosh', 1, 1) @_scal_elemwise_with_nfunc('cosh', 1, 1)
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise_with_nfunc('sinh', 1, 1) @_scal_elemwise_with_nfunc('sinh', 1, 1)
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise_with_nfunc('tanh', 1, 1) @_scal_elemwise_with_nfunc('tanh', 1, 1)
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise @_scal_elemwise
def erf(a): def erf(a):
"""error function""" """error function"""
@_scal_elemwise @_scal_elemwise
def erfc(a): def erfc(a):
"""complementary error function""" """complementary error function"""
@_scal_elemwise_with_nfunc('real', 1, -1) @_scal_elemwise_with_nfunc('real', 1, -1)
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
@_scal_elemwise_with_nfunc('imag', 1, -1) @_scal_elemwise_with_nfunc('imag', 1, -1)
def imag(z): def imag(z):
"""Return imaginary component of complex-valued tensor `z`""" """Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise_with_nfunc('angle', 1, -1) @_scal_elemwise_with_nfunc('angle', 1, -1)
def angle(z): def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`""" """Return polar-coordinate angle of complex-valued tensor `z`"""
@_scal_elemwise # numpy.complex cannot build tensors
@_scal_elemwise # numpy.complex cannot build tensors
def complex(real, imag): def complex(real, imag):
"""Return complex-valued tensor with `real` and `imag` components""" """Return complex-valued tensor with `real` and `imag` components"""
@_scal_elemwise @_scal_elemwise
def complex_from_polar(abs, angle): def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification""" """Return complex-valued tensor from polar coordinate specification"""
########################## ##########################
# Misc # Misc
########################## ##########################
...@@ -2434,9 +2631,10 @@ def ones_like(model, dtype=None): ...@@ -2434,9 +2631,10 @@ def ones_like(model, dtype=None):
"""equivalent of numpy.ones_like""" """equivalent of numpy.ones_like"""
if dtype is None: if dtype is None:
dtype = model.type.dtype dtype = model.type.dtype
ret= fill(model, constant(1.0, dtype=dtype)) ret = fill(model, constant(1.0, dtype=dtype))
return ret return ret
@constructor @constructor
def zeros_like(model, dtype=None): def zeros_like(model, dtype=None):
"""equivalent of numpy.zeros_like""" """equivalent of numpy.zeros_like"""
...@@ -2444,6 +2642,7 @@ def zeros_like(model, dtype=None): ...@@ -2444,6 +2642,7 @@ def zeros_like(model, dtype=None):
dtype = model.type.dtype dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype)) return fill(model, constant(0.0, dtype=dtype))
def zeros(shape, dtype=config.floatX): def zeros(shape, dtype=config.floatX):
""" """
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``. Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
...@@ -2458,39 +2657,41 @@ def ones(shape, dtype=config.floatX): ...@@ -2458,39 +2657,41 @@ def ones(shape, dtype=config.floatX):
return alloc(numpy.array(1, dtype=dtype), *shape) return alloc(numpy.array(1, dtype=dtype), *shape)
class Eye(gof.Op): class Eye(gof.Op):
def __init__(self, dtype=config.floatX): def __init__(self, dtype=config.floatX):
self.dtype = dtype self.dtype = dtype
def make_node(self,n,m,k):
def make_node(self, n, m, k):
n = as_tensor_variable(n) n = as_tensor_variable(n)
m = as_tensor_variable(m) m = as_tensor_variable(m)
k = as_tensor_variable(k) k = as_tensor_variable(k)
return gof.Apply(self, [n,m,k], [TensorType(dtype = self.dtype, broadcastable = (False,False))()]) return gof.Apply(self, [n, m, k],
[TensorType(dtype=self.dtype, broadcastable=(False, False))()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
n, m, k = inp n, m, k = inp
out, = out_ out, = out_
out[0] = numpy.eye(n,m,k,dtype=self.dtype) out[0] = numpy.eye(n, m, k, dtype=self.dtype)
def grad(self, inp, grads): def grad(self, inp, grads):
return [None, None, None] return [None, None, None]
def __eq__(self,other): def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype return type(self) == type(other) and self.dtype == other.dtype
def __hash__(self): def __hash__(self):
return hash(self.dtype) ^ hash(type(self)) return hash(self.dtype) ^ hash(type(self))
def eye(n, m=None, k = 0, dtype = config.floatX): def eye(n, m=None, k=0, dtype=config.floatX):
if m == None: if m == None:
m = n m = n
localop = Eye(dtype) localop = Eye(dtype)
return localop(n,m,k) return localop(n, m, k)
def identity_like(x): def identity_like(x):
return eye(x.shape[0], x.shape[1], k=0, dtype = x.dtype) return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
if 0: if 0:
## COMMENTED OUT FEB 17 2010 ## COMMENTED OUT FEB 17 2010
...@@ -2552,20 +2753,22 @@ if 0: ...@@ -2552,20 +2753,22 @@ if 0:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros'))
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones'))
class Alloc(gof.Op): class Alloc(gof.Op):
"""Create a Tensor from an initial value and a desired shape """Create a Tensor from an initial value and a desired shape
alloc(value, shape0, shape1, ..., shapeN) alloc(value, shape0, shape1, ..., shapeN)
Returns an N-dimensional tensor initialized by `value` using something equivalent to Returns an N-dimensional tensor initialized by `value` using something
equivalent to
>>> z = numpy.zeros(shape, value.dtype) >>> z = numpy.zeros(shape, value.dtype)
>>> z += value >>> z += value
The result has N dimensions, has the dtype of `value` and is obtained by broadcasting value The result has N dimensions, has the dtype of `value` and is obtained by
over the output ndarray. broadcasting value over the output ndarray.
This Op is used to replace fill() during optimizations because after shapes are lifted, This Op is used to replace fill() during optimizations because after shapes
the first argument to fill can often be pruned from the graph. are lifted, the first argument to fill can often be pruned from the graph.
""" """
def __init__(self): def __init__(self):
pass pass
...@@ -2599,7 +2802,7 @@ class Alloc(gof.Op): ...@@ -2599,7 +2802,7 @@ class Alloc(gof.Op):
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, [v]+sh, [otype()]) return gof.Apply(self, ([v] + sh), [otype()])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
out, = out_ out, = out_
...@@ -2610,10 +2813,10 @@ class Alloc(gof.Op): ...@@ -2610,10 +2813,10 @@ class Alloc(gof.Op):
out[0] = numpy.zeros(sh, dtype=v.dtype) out[0] = numpy.zeros(sh, dtype=v.dtype)
else: else:
out[0] = numpy.empty(sh, dtype=v.dtype) out[0] = numpy.empty(sh, dtype=v.dtype)
out[0][...] = v # broadcast v to fill us up out[0][...] = v # broadcast v to fill us up
else: else:
#reuse the allocated memory. #reuse the allocated memory.
out[0][...] = v # broadcast v to fill us up out[0][...] = v # broadcast v to fill us up
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
# TODO: use the elemwise code generator here # TODO: use the elemwise code generator here
...@@ -2647,6 +2850,7 @@ class Alloc(gof.Op): ...@@ -2647,6 +2850,7 @@ class Alloc(gof.Op):
zz[i] = vv; zz[i] = vv;
} }
""" % locals() """ % locals()
# else pretend this never happened # else pretend this never happened
return super(Alloc, self).c_code(node, name, inp, out, sub) return super(Alloc, self).c_code(node, name, inp, out, sub)
...@@ -2665,10 +2869,11 @@ class Alloc(gof.Op): ...@@ -2665,10 +2869,11 @@ class Alloc(gof.Op):
If the alloc would be useless, this function returns val. If the alloc would be useless, this function returns val.
If you always want an Alloc node, call make_node. If you always want an Alloc node, call make_node.
""" """
ret = super(Alloc,self).__call__(val, *shapes) ret = super(Alloc, self).__call__(val, *shapes)
try: try:
#It makes optimization difficult when useless allocs are thrown into the graph at every # It makes optimization difficult when useless allocs are thrown
#stage of optimization. This little logic tries to help at least in some cases. # into the graph at every stage of optimization. This little logic
# tries to help at least in some cases.
if val.type == ret.type: if val.type == ret.type:
return val return val
except AttributeError: except AttributeError:
...@@ -2729,9 +2934,11 @@ def prod(input, axis=None, dtype=None): ...@@ -2729,9 +2934,11 @@ def prod(input, axis=None, dtype=None):
""" """
return elemwise.Prod(axis, dtype=dtype)(input) return elemwise.Prod(axis, dtype=dtype)(input)
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis = None): def __init__(self, axis=None):
elemwise.CAReduce.__init__(self, scal.add, axis) elemwise.CAReduce.__init__(self, scal.add, axis)
def __str__(self): def __str__(self):
if self.axis is not None: if self.axis is not None:
return "Mean{%s}" % (", ".join(str(x) for x in self.axis)) return "Mean{%s}" % (", ".join(str(x) for x in self.axis))
...@@ -2745,22 +2952,24 @@ class Mean(elemwise.CAReduce): ...@@ -2745,22 +2952,24 @@ class Mean(elemwise.CAReduce):
def perform(self, node, inp, out): def perform(self, node, inp, out):
input, = inp input, = inp
output, = out output, = out
output[0]=numpy.mean(input,axis=self.axis) output[0] = numpy.mean(input, axis=self.axis)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
if self.axis!=None: if self.axis != None:
return super(Op, self).c_code(node, name, inames, onames, sub) return super(Op, self).c_code(node, name, inames, onames, sub)
ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub) ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub)
#TODO: c_code perform support only axis==None #TODO: c_code perform support only axis==None
return ret + """ return ret + """
*((double *)PyArray_DATA(%s)) /= PyArray_SIZE(%s); *((double *)PyArray_DATA(%s)) /= PyArray_SIZE(%s);
"""%(onames[0],inames[0]) """ % (onames[0], inames[0])
#TODO: implement the grad. When done and tested, you can make this the default version. #TODO: implement the grad. When done and tested, you can make this the default
# version.
# def grad(self, (x,), (gout,)): # def grad(self, (x,), (gout,)):
# import pdb;pdb.set_trace() # import pdb;pdb.set_trace()
# return grad(mean(x, self.axis, op=False),[x]) # return grad(mean(x, self.axis, op=False),[x])
@constructor @constructor
def mean(input, axis=None, dtype=None, op=False): def mean(input, axis=None, dtype=None, op=False):
"""Compute the mean value along the given axis of a tensor `input` """Compute the mean value along the given axis of a tensor `input`
...@@ -2817,8 +3026,9 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -2817,8 +3026,9 @@ def mean(input, axis=None, dtype=None, op=False):
return s return s
@constructor @constructor
def var(input, axis = None): def var(input, axis=None):
"""Compute the variance along the given axis of a tensor `input`. """Compute the variance along the given axis of a tensor `input`.
:param axis: Compute the variance along this axis of the tensor. :param axis: Compute the variance along this axis of the tensor.
...@@ -2854,7 +3064,8 @@ def var(input, axis = None): ...@@ -2854,7 +3064,8 @@ def var(input, axis = None):
centered_input = input - mean_input centered_input = input - mean_input
#return the mean sqr #return the mean sqr
return mean(centered_input**2, axis) return mean((centered_input ** 2), axis)
@constructor @constructor
def std(input, axis=None): def std(input, axis=None):
...@@ -2901,6 +3112,7 @@ if 0: ...@@ -2901,6 +3112,7 @@ if 0:
repeat = Repeat() repeat = Repeat()
class Default(gof.Op): class Default(gof.Op):
""" """
Takes an input x and a default value. If the input is not None, a Takes an input x and a default value. If the input is not None, a
...@@ -2909,39 +3121,45 @@ class Default(gof.Op): ...@@ -2909,39 +3121,45 @@ class Default(gof.Op):
have exactly the same type. have exactly the same type.
""" """
view_map = {0: [0]} view_map = {0: [0]}
def make_node(self, x, default): def make_node(self, x, default):
x, default = as_tensor_variable(x), as_tensor_variable(default) x, default = as_tensor_variable(x), as_tensor_variable(default)
if x.type != default.type: if x.type != default.type:
raise TypeError('Both default() arguments must have same type', x, default) raise TypeError('Both default() arguments must have same type',
x, default)
return gof.Apply(self, [x, default], [default.type()]) return gof.Apply(self, [x, default], [default.type()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, default = inp x, default = inp
out, = out_ out, = out_
if x is None: if x is None:
# why copy? Theano can't yet understand out[0] being a view of either x or y, # why copy? Theano can't yet understand out[0] being a view of
# so we can be a view of x, but only a copy of y. # either x or y, so we can be a view of x, but only a copy of y.
out[0] = default.copy() out[0] = default.copy()
else: else:
out[0] = x out[0] = x
default = Default() default = Default()
setdefault = default # legacy setdefault = default # legacy
########################## ##########################
# Arithmetics # Arithmetics
########################## ##########################
@_scal_elemwise_with_nfunc('maximum', 2, 1) @_scal_elemwise_with_nfunc('maximum', 2, 1)
def maximum(x,y): def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor """elemwise maximum. See max for the maximum in one tensor
""" """
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('minimum', 2, 1) @_scal_elemwise_with_nfunc('minimum', 2, 1)
def minimum(x,y): def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor """elemwise minimum. See min for the minimum in one tensor
""" """
# see decorator for function body # see decorator for function body
def div_proxy(x, y): def div_proxy(x, y):
"""Proxy for either true_div or int_div, depending on types of x, y.""" """Proxy for either true_div or int_div, depending on types of x, y."""
f = eval('%s_div' % scal.int_or_true_div( f = eval('%s_div' % scal.int_or_true_div(
...@@ -2949,32 +3167,39 @@ def div_proxy(x, y): ...@@ -2949,32 +3167,39 @@ def div_proxy(x, y):
as_tensor_variable(y).dtype in discrete_dtypes)) as_tensor_variable(y).dtype in discrete_dtypes))
return f(x, y) return f(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1) @_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('subtract', 2, 1) @_scal_elemwise_with_nfunc('subtract', 2, 1)
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('multiply', 2, 1) @_scal_elemwise_with_nfunc('multiply', 2, 1)
def mul(a, *other_terms): def mul(a, *other_terms):
"""elementwise multiplication""" """elementwise multiplication"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('true_divide', 2, 1) @_scal_elemwise_with_nfunc('true_divide', 2, 1)
def true_div(a, b): def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1) @_scal_elemwise_with_nfunc('floor_divide', 2, 1)
def floor_div(a, b): def floor_div(a, b):
"""elementwise [floor] division (inverse of multiplication)""" """elementwise [floor] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('floor_divide', 2, 1) # not a c/p error, floor_div and int_div are the same thing
# not a c/p error, floor_div and int_div are the same thing
@_scal_elemwise_with_nfunc('floor_divide', 2, 1)
def int_div(a, b): def int_div(a, b):
"""elementwise integer-division""" """elementwise integer-division"""
# see decorator for function body # see decorator for function body
...@@ -3009,19 +3234,22 @@ def mod_check(x, y): ...@@ -3009,19 +3234,22 @@ def mod_check(x, y):
else: else:
return mod(x, y) return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1) @_scal_elemwise_with_nfunc('mod', 2, 1)
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
# see decorator for function body # see decorator for function body
@_scal_elemwise_with_nfunc('power', 2, 1) @_scal_elemwise_with_nfunc('power', 2, 1)
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# see decorator for function body # see decorator for function body
# The numpy.clip don't work correctly when # The numpy.clip don't work correctly when
# the min is bigger then the max # the min is bigger then the max
@_scal_elemwise #_with_nfunc('clip', 3, 1) @_scal_elemwise # _with_nfunc('clip', 3, 1)
def clip(x, min, max): def clip(x, min, max):
"""clip x to be between min and max""" """clip x to be between min and max"""
# see decorator for function body # see decorator for function body
...@@ -3036,7 +3264,6 @@ pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left')) ...@@ -3036,7 +3264,6 @@ pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left'))
pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
########################## ##########################
# View Operations # View Operations
########################## ##########################
...@@ -3045,9 +3272,6 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -3045,9 +3272,6 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# Helpful functions to deal with Subtensor and IncSubtensor # Helpful functions to deal with Subtensor and IncSubtensor
########## ##########
def get_idx_list(inputs, idx_list): def get_idx_list(inputs, idx_list):
''' '''
Given a list of inputs to the subtensor and its idx_list reorders Given a list of inputs to the subtensor and its idx_list reorders
...@@ -3105,53 +3329,52 @@ def get_canonical_form_slice(theslice, length): ...@@ -3105,53 +3329,52 @@ def get_canonical_form_slice(theslice, length):
resulting set of numbers needs to be reversed or not. resulting set of numbers needs to be reversed or not.
''' '''
if isinstance(theslice, slice):
if isinstance(theslice,slice):
start = extract_constant(theslice.start) start = extract_constant(theslice.start)
stop = extract_constant(theslice.stop) stop = extract_constant(theslice.stop)
step = extract_constant(theslice.step) step = extract_constant(theslice.step)
if step is None: if step is None:
step = 1 step = 1
defstart = switch(lt(step,0), length-1, 0) defstart = switch(lt(step, 0), (length - 1), 0)
defstop = switch(lt(step,0), -1, length ) defstop = switch(lt(step, 0), -1, length)
if start is None: if start is None:
start = defstart start = defstart
else: else:
start = switch(lt(start,0), start + length, start) start = switch(lt(start, 0), start + length, start)
start = switch(lt(start,0), switch(lt(step,0), -1, 0), start) start = switch(lt(start, 0), switch(lt(step, 0), -1, 0), start)
start = switch(ge(start,length) start = switch(ge(start, length),
, switch(lt(step,0),length-1,length) switch(lt(step, 0), (length - 1), length),
, start) start)
if stop in [None, maxsize]: if stop in [None, maxsize]:
# The special "maxsize" case is probably not needed here, # The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by # as slices containing maxsize are not generated by
# __getslice__ anymore. # __getslice__ anymore.
stop = defstop stop = defstop
else: else:
stop = switch(lt(stop,0), stop + length, stop) stop = switch(lt(stop, 0), stop + length, stop)
stop = switch(lt(stop,0), -1, stop) stop = switch(lt(stop, 0), -1, stop)
stop = switch(ge(stop,length), length,stop) stop = switch(ge(stop, length), length, stop)
nw_stop = switch(lt(step,0), start+1, stop ) nw_stop = switch(lt(step, 0), (start + 1), stop)
slice_len = ( start -stop - 1)//abs(step) + 1 slice_len = (start - stop - 1) // abs(step) + 1
slice_len = switch(lt(slice_len,0), 0, slice_len) slice_len = switch(lt(slice_len, 0), 0, slice_len)
neg_start = nw_stop - (slice_len-1)*abs(step)-1 neg_start = nw_stop - (slice_len - 1) * abs(step) - 1
neg_start = switch(lt(neg_start,0), nw_stop-1, neg_start) neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
nw_start = switch(lt(step,0), neg_start, start) nw_start = switch(lt(step, 0), neg_start, start)
nw_start = switch(lt(nw_start,0), 0, nw_start) nw_start = switch(lt(nw_start, 0), 0, nw_start)
nw_stop = switch(lt(nw_stop,0) , 0, nw_stop ) nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
nw_step = abs(step) nw_step = abs(step)
if step != 1: if step != 1:
reverse = sgn(step) reverse = sgn(step)
return slice(nw_start, nw_stop, nw_step), reverse return slice(nw_start, nw_stop, nw_step), reverse
else: else:
return slice(nw_start, nw_stop, nw_step), 1 return slice(nw_start, nw_stop, nw_step), 1
else: else:
value = extract_constant(theslice) value = extract_constant(theslice)
value = switch(lt(value,0), value+length, value) value = switch(lt(value, 0), (value + length), value)
return value, 1 return value, 1
...@@ -3165,7 +3388,7 @@ def transpose(x, axes=None): ...@@ -3165,7 +3388,7 @@ def transpose(x, axes=None):
""" """
if axes is None: if axes is None:
axes = range(x.ndim-1, -1, -1) axes = range((x.ndim - 1), -1, -1)
return DimShuffle(x.broadcastable, axes, inplace=False)(x) return DimShuffle(x.broadcastable, axes, inplace=False)(x)
...@@ -3175,7 +3398,7 @@ class AdvancedIndexingError(TypeError): ...@@ -3175,7 +3398,7 @@ class AdvancedIndexingError(TypeError):
""" """
def __init__(self, *args): def __init__(self, *args):
TypeError.__init__( self, *args) TypeError.__init__(self, *args)
class Subtensor(Op): class Subtensor(Op):
...@@ -3185,8 +3408,8 @@ class Subtensor(Op): ...@@ -3185,8 +3408,8 @@ class Subtensor(Op):
to remember how the input tensor x should be sliced. The instance variable to remember how the input tensor x should be sliced. The instance variable
idx_list is a list whose elements are either integers, or slices. The idx_list is a list whose elements are either integers, or slices. The
integers are indexes into the inputs array, and the start/stop/step members integers are indexes into the inputs array, and the start/stop/step members
of each slice are also integer indexes into the inputs array (or None). The of each slice are also integer indexes into the inputs array (or None).
inputs array is the tensor x, followed by scalar integer variables. The inputs array is the tensor x, followed by scalar integer variables.
@todo: add support for advanced tensor indexing (in Subtensor_dx too). @todo: add support for advanced tensor indexing (in Subtensor_dx too).
...@@ -3197,7 +3420,7 @@ class Subtensor(Op): ...@@ -3197,7 +3420,7 @@ class Subtensor(Op):
additionally be a Scalar instance, and slice components can also be Scalar additionally be a Scalar instance, and slice components can also be Scalar
instances too. instances too.
""" """
e_invalid = ( 'The index list is longer (size %d) than the number of ' e_invalid = ('The index list is longer (size %d) than the number of '
'dimensions of the tensor(namely %d). You are asking for ' 'dimensions of the tensor(namely %d). You are asking for '
'a dimension of the tensor that does not exist! You might ' 'a dimension of the tensor that does not exist! You might '
'need to use dimshuffle to add extra dimension to your ' 'need to use dimshuffle to add extra dimension to your '
...@@ -3211,33 +3434,42 @@ class Subtensor(Op): ...@@ -3211,33 +3434,42 @@ class Subtensor(Op):
@staticmethod @staticmethod
def collapse(idxs, cond): def collapse(idxs, cond):
ret = [] ret = []
def helper(entry): def helper(entry):
if cond(entry): if cond(entry):
ret.append(entry) ret.append(entry)
elif isinstance(entry, slice): elif isinstance(entry, slice):
helper(entry.start) helper(entry.start)
helper(entry.stop) helper(entry.stop)
helper( entry.step) helper(entry.step)
for idx in idxs: for idx in idxs:
helper(idx) helper(idx)
return ret return ret
@staticmethod @staticmethod
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
invalid_scal_types = [scal.float64, scal.float32 ] invalid_scal_types = [scal.float64, scal.float32]
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [lscalar, iscalar, wscalar, bscalar] tensor_types = [lscalar, iscalar, wscalar, bscalar]
invalid_tensor_types = [fscalar, dscalar, cscalar, zscalar ] invalid_tensor_types = [fscalar, dscalar, cscalar, zscalar]
if isinstance(entry, gof.Variable) and (entry.type in invalid_scal_types \ if (isinstance(entry, gof.Variable)
or entry.type in invalid_tensor_types): and (entry.type in invalid_scal_types
or entry.type in invalid_tensor_types)):
raise TypeError("Expected an integer") raise TypeError("Expected an integer")
if isinstance(entry, gof.Variable) and entry.type in scal_types: if isinstance(entry, gof.Variable) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, gof.Type) and entry in scal_types:
return entry return entry
if isinstance(entry, gof.Variable) and entry.type in tensor_types and numpy.all(entry.type.broadcastable):
if (isinstance(entry, gof.Variable)
and entry.type in tensor_types
and numpy.all(entry.type.broadcastable)):
return scal.Scalar(entry.type.dtype) return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types and numpy.all(entry.broadcastable): elif (isinstance(entry, gof.Type)
and entry in tensor_types
and numpy.all(entry.broadcastable)):
return scal.Scalar(entry.dtype) return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
...@@ -3283,15 +3515,14 @@ class Subtensor(Op): ...@@ -3283,15 +3515,14 @@ class Subtensor(Op):
else: else:
return scal.as_scalar(a) return scal.as_scalar(a)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor_variable(x) x = as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs) inputs = tuple(self.my_as_scalar(a) for a in inputs)
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
exception = ValueError(Subtensor.e_invalid%(len(idx_list), exception = ValueError(Subtensor.e_invalid % (
x.type.ndim)) len(idx_list), x.type.ndim))
exception.subtensor_invalid = True exception.subtensor_invalid = True
raise exception raise exception
...@@ -3310,13 +3541,13 @@ class Subtensor(Op): ...@@ -3310,13 +3541,13 @@ class Subtensor(Op):
for input, expected_type in zip(inputs, input_types): for input, expected_type in zip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s."%( "Wrong type for Subtensor template. Expected %s, got %s."
input.type, expected_type)) % (input.type, expected_type))
return gof.Apply(self, return gof.Apply(self,
(x, ) + inputs, (x, ) + inputs,
[tensor(dtype = x.type.dtype, [tensor(dtype=x.type.dtype,
broadcastable = broadcastable)]) broadcastable=broadcastable)])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
out, = out_ out, = out_
...@@ -3342,21 +3573,22 @@ class Subtensor(Op): ...@@ -3342,21 +3573,22 @@ class Subtensor(Op):
assert len(xshp) == node.inputs[0].ndim assert len(xshp) == node.inputs[0].ndim
outshp = [] outshp = []
actual_idx_list = list(get_idx_list(node.inputs, self.idx_list)) actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
padded = ( actual_idx_list + padded = (actual_idx_list +
[slice(None, None, None)]*(len(xshp)-len(self.idx_list))) [slice(None, None, None)] * (len(xshp) - len(self.idx_list)))
i = 0 i = 0
for idx, xl in izip(padded, xshp): for idx, xl in izip(padded, xshp):
if isinstance(idx, slice): if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant, # If it is the default (None, None, None) slice, or a variant,
# the shape will be xl # the shape will be xl
if ( (idx.start in [None, 0]) if ((idx.start in [None, 0])
and (idx.stop in [None, maxsize]) and (idx.stop in [None, maxsize])
and (idx.step is None or idx.step == 1) ): and (idx.step is None or idx.step == 1)):
outshp.append(xl) outshp.append(xl)
else: else:
cnf = get_canonical_form_slice(idx, xl) cnf = get_canonical_form_slice(idx, xl)
length = (cnf[0].stop - cnf[0].start -1) // cnf[0].step + 1 length = ((cnf[0].stop - cnf[0].start - 1) // cnf[0].step
length = switch(lt(length,0), 0, length) + 1)
length = switch(lt(length, 0), 0, length)
outshp.append(length) outshp.append(length)
i += 1 i += 1
else: else:
...@@ -3370,7 +3602,8 @@ class Subtensor(Op): ...@@ -3370,7 +3602,8 @@ class Subtensor(Op):
gz, = grads gz, = grads
x = inputs[0] x = inputs[0]
rest = inputs[1:] rest = inputs[1:]
return [IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest) return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)]
+ [None] * len(rest))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
...@@ -3401,6 +3634,7 @@ class Subtensor(Op): ...@@ -3401,6 +3634,7 @@ class Subtensor(Op):
else: else:
msg.append(str(x)) msg.append(str(x))
return ":".join(msg) return ":".join(msg)
def __str__(self): def __str__(self):
indices = [] indices = []
for entry in self.idx_list: for entry in self.idx_list:
...@@ -3420,44 +3654,53 @@ class Subtensor(Op): ...@@ -3420,44 +3654,53 @@ class Subtensor(Op):
# subtensor_spec: len = n_ints + 3 * n_slices # subtensor_spec: len = n_ints + 3 * n_slices
# #
fail = sub['fail'] fail = sub['fail']
init_cmds = [] # initialization for subtensor_spec init_cmds = [] # initialization for subtensor_spec
is_slice = [] is_slice = []
#TODO: change that, it might lead to unexpected results, #TODO: change that, it might lead to unexpected results,
# see assembla-#767 # see assembla-#767
NONE_CODE = maxsize - 1 NONE_CODE = maxsize - 1
pos = [0,1] #annoying version of global variable for init_entry pos = [0, 1] # annoying version of global variable for init_entry
def inc_spec_pos(amt): pos[0] += amt
def inc_input_pos(amt): pos[1] += amt def inc_spec_pos(amt):
def spec_pos(): return pos[0] pos[0] += amt
def input_pos(): return pos[1]
def inc_input_pos(amt):
pos[1] += amt
def spec_pos():
return pos[0]
def input_pos():
return pos[1]
def init_entry(entry, depth=0): def init_entry(entry, depth=0):
if isinstance(entry, int): if isinstance(entry, int):
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %i;" %(spec_pos(), "subtensor_spec[%i] = %i;" % (spec_pos(),
entry)) entry))
inc_spec_pos(1) inc_spec_pos(1)
if depth==0: if depth == 0:
is_slice.append(0) is_slice.append(0)
elif isinstance(entry, Type): elif isinstance(entry, Type):
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %s;" %(spec_pos(), "subtensor_spec[%i] = %s;" % (spec_pos(),
inputs[input_pos()])) inputs[input_pos()]))
inc_spec_pos(1) inc_spec_pos(1)
inc_input_pos(1) inc_input_pos(1)
if depth==0: if depth == 0:
is_slice.append(0) is_slice.append(0)
elif entry is None: elif entry is None:
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %i;" %(spec_pos(), "subtensor_spec[%i] = %i;" % (spec_pos(),
NONE_CODE)) NONE_CODE))
inc_spec_pos(1) inc_spec_pos(1)
if depth==0: if depth == 0:
is_slice.append(0) is_slice.append(0)
elif depth==0 and isinstance(entry, slice): elif depth == 0 and isinstance(entry, slice):
init_entry(entry.start, depth+1) init_entry(entry.start, depth + 1)
init_entry(entry.stop, depth+1) init_entry(entry.stop, depth + 1)
init_entry(entry.step, depth+1) init_entry(entry.step, depth + 1)
is_slice.append(1) is_slice.append(1)
else: else:
assert 0, entry assert 0, entry
...@@ -3469,7 +3712,7 @@ class Subtensor(Op): ...@@ -3469,7 +3712,7 @@ class Subtensor(Op):
assert len(is_slice) <= node.inputs[0].ndim, node.inputs[0].ndim assert len(is_slice) <= node.inputs[0].ndim, node.inputs[0].ndim
len_is_slice = len(is_slice) len_is_slice = len(is_slice)
view_ndim = node.inputs[0].ndim - (numpy.asarray(is_slice)==0).sum() view_ndim = node.inputs[0].ndim - (numpy.asarray(is_slice) == 0).sum()
len_subtensor_spec = spec_pos() len_subtensor_spec = spec_pos()
...@@ -3635,7 +3878,7 @@ class Subtensor(Op): ...@@ -3635,7 +3878,7 @@ class Subtensor(Op):
outer_ii += 1; outer_ii += 1;
} }
PyArray_UpdateFlags(xview, NPY_C_CONTIGUOUS|NPY_F_CONTIGUOUS); PyArray_UpdateFlags(xview, NPY_C_CONTIGUOUS|NPY_F_CONTIGUOUS);
"""% locals() """ % locals()
#print rval #print rval
return rval return rval
...@@ -3643,7 +3886,7 @@ class Subtensor(Op): ...@@ -3643,7 +3886,7 @@ class Subtensor(Op):
def helper_c_code_cache_version(): def helper_c_code_cache_version():
return (4,) return (4,)
def c_code(self, node, name, inputs, outputs, sub): #DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
part0 = self.helper_c_code(node, name, inputs, outputs, sub, part0 = self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list) self.idx_list)
...@@ -3655,11 +3898,10 @@ class Subtensor(Op): ...@@ -3655,11 +3898,10 @@ class Subtensor(Op):
xview->base = py_%(x)s; xview->base = py_%(x)s;
assert(py_%(x)s == (PyObject*)%(x)s); assert(py_%(x)s == (PyObject*)%(x)s);
%(z)s = xview; %(z)s = xview;
""" %locals() """ % locals()
return part0 + part1 return part0 + part1
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
# If `helper_c_code_cache_version` is not versioned we do not want to # If `helper_c_code_cache_version` is not versioned we do not want to
...@@ -3676,6 +3918,7 @@ class Subtensor(Op): ...@@ -3676,6 +3918,7 @@ class Subtensor(Op):
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
class SubtensorPrinter: class SubtensorPrinter:
def process(self, r, pstate): def process(self, r, pstate):
...@@ -3686,34 +3929,39 @@ class SubtensorPrinter: ...@@ -3686,34 +3929,39 @@ class SubtensorPrinter:
inputs = list(r.owner.inputs) inputs = list(r.owner.inputs)
input = inputs.pop() input = inputs.pop()
sidxs = [] sidxs = []
inbrack_pstate = pstate.clone(precedence = -1000) inbrack_pstate = pstate.clone(precedence=-1000)
for entry in idxs: for entry in idxs:
if isinstance(entry, int): if isinstance(entry, int):
sidxs.append(str(entry)) sidxs.append(str(entry))
elif isinstance(entry, scal.Scalar): elif isinstance(entry, scal.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop())) sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice): elif isinstance(entry, slice):
if entry.start is None or entry.start==0: if entry.start is None or entry.start == 0:
msg1 = "" msg1 = ""
else: else:
msg1 = entry.start msg1 = entry.start
if entry.stop is None or entry.stop == maxsize: if entry.stop is None or entry.stop == maxsize:
msg2 = "" msg2 = ""
else: else:
msg2 = entry.stop msg2 = entry.stop
if entry.step is None: if entry.step is None:
msg3 = "" msg3 = ""
else: else:
msg3 = ":%s" % entry.step msg3 = ":%s" % entry.step
sidxs.append("%s:%s%s" % (msg1, msg2, msg3)) sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)), ", ".join(sidxs)) return "%s[%s]" % (pstate.pprinter.process(
input,
pstate.clone(precedence=1000)),
", ".join(sidxs))
else: else:
raise TypeError("Can only print Subtensor.") raise TypeError("Can only print Subtensor.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter()) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor),
SubtensorPrinter())
def set_subtensor(x, y, inplace=False, def set_subtensor(x, y, inplace=False,
tolerate_inplace_aliasing=False): tolerate_inplace_aliasing=False):
...@@ -3730,6 +3978,7 @@ def set_subtensor(x, y, inplace=False, ...@@ -3730,6 +3978,7 @@ def set_subtensor(x, y, inplace=False,
return inc_subtensor(x, y, inplace, set_instead_of_inc=True, return inc_subtensor(x, y, inplace, set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing) tolerate_inplace_aliasing=tolerate_inplace_aliasing)
def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
tolerate_inplace_aliasing=False): tolerate_inplace_aliasing=False):
"""Return x with the given subtensor incremented by y. """Return x with the given subtensor incremented by y.
...@@ -3753,7 +4002,8 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -3753,7 +4002,8 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
else: else:
destroyhandler_tolerate_aliased = [] destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc, the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased) destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased
)
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:] real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs) return the_op(real_x, y, *real_idxargs)
...@@ -3790,7 +4040,8 @@ class IncSubtensor(Op): ...@@ -3790,7 +4040,8 @@ class IncSubtensor(Op):
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = list(destroyhandler_tolerate_aliased) self.destroyhandler_tolerate_aliased = list(
destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other): def __eq__(self, other):
...@@ -3843,19 +4094,13 @@ class IncSubtensor(Op): ...@@ -3843,19 +4094,13 @@ class IncSubtensor(Op):
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
exception = ValueError( exception = ValueError(
Subtensor.e_invalid%( Subtensor.e_invalid % (
len(idx_list), len(idx_list),
x.type.ndim)) x.type.ndim))
exception.subtensor_invalid = True exception.subtensor_invalid = True
raise exception raise exception
#infer the broadcasting pattern input_types = Subtensor.collapse(idx_list,
padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = [bc for p, bc in zip(padded, x.type.broadcastable)
if isinstance(p, slice)]
input_types = Subtensor.collapse( idx_list,
lambda entry: isinstance(entry, gof.Type)) lambda entry: isinstance(entry, gof.Type))
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
raise IndexError( raise IndexError(
...@@ -3864,8 +4109,8 @@ class IncSubtensor(Op): ...@@ -3864,8 +4109,8 @@ class IncSubtensor(Op):
for input, expected_type in zip(inputs, input_types): for input, expected_type in zip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s."%( "Wrong type for Subtensor template. Expected %s, got %s."
input.type, expected_type)) % (input.type, expected_type))
return gof.Apply(self, return gof.Apply(self,
(x, y) + inputs, (x, y) + inputs,
...@@ -3907,16 +4152,16 @@ class IncSubtensor(Op): ...@@ -3907,16 +4152,16 @@ class IncSubtensor(Op):
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
out[0] = x out[0] = x
def c_code(self, node, name, inputs, outputs, sub): #DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
if self.inplace: # convert bool to int if self.inplace: # convert bool to int
inplace = 1 inplace = 1
else: else:
inplace = 0 inplace = 0
x = inputs[0] x = inputs[0]
y = inputs[1] y = inputs[1]
z, = outputs z, = outputs
if self.set_instead_of_inc: # convert bool to int if self.set_instead_of_inc: # convert bool to int
op_is_set = 1 op_is_set = 1
else: else:
op_is_set = 0 op_is_set = 0
...@@ -3941,10 +4186,9 @@ class IncSubtensor(Op): ...@@ -3941,10 +4186,9 @@ class IncSubtensor(Op):
# make xview actually a view of %(z)s # make xview actually a view of %(z)s
get_xview = Subtensor.helper_c_code(node, name, get_xview = Subtensor.helper_c_code(node, name,
outputs[:1]+inputs[2:], outputs[:1] + inputs[2:],
outputs, sub, self.idx_list) outputs, sub, self.idx_list)
make_modification = """ make_modification = """
if (%(op_is_set)s) if (%(op_is_set)s)
{ {
...@@ -3970,7 +4214,7 @@ class IncSubtensor(Op): ...@@ -3970,7 +4214,7 @@ class IncSubtensor(Op):
%(fail)s; %(fail)s;
} }
} }
""" %locals() """ % locals()
return (copy_input_if_necessary return (copy_input_if_necessary
+ get_xview + get_xview
...@@ -4295,7 +4539,6 @@ class Join(Op): ...@@ -4295,7 +4539,6 @@ class Join(Op):
def _make_node_internal(self, axis, tensors, def _make_node_internal(self, axis, tensors,
as_tensor_variable_args, output_maker): as_tensor_variable_args, output_maker):
orig = as_tensor_variable_args
if not python_all(targs.type.ndim for targs if not python_all(targs.type.ndim for targs
in as_tensor_variable_args): in as_tensor_variable_args):
raise TypeError('Join cannot handle arguments of dimension 0.' raise TypeError('Join cannot handle arguments of dimension 0.'
...@@ -4340,7 +4583,7 @@ class Join(Op): ...@@ -4340,7 +4583,7 @@ class Join(Op):
bcastable[current_axis] = True bcastable[current_axis] = True
try: try:
bcastable[axis] = False bcastable[axis] = False
except IndexError, e: except IndexError:
raise ValueError('Join argument "axis" is out of range' raise ValueError('Join argument "axis" is out of range'
' (given input dimensions)') ' (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis) as_tensor_variable_args = [unbroadcast(x, axis)
...@@ -4699,7 +4942,6 @@ if 0: ...@@ -4699,7 +4942,6 @@ if 0:
x, y = inp x, y = inp
gz, = grads gz, = grads
xs = shape(x) xs = shape(x)
ys = shape(y)
return gz[:xs[0]], gz[xs[0]:] return gz[:xs[0]], gz[xs[0]:]
vertical_stack = VerticalStack() vertical_stack = VerticalStack()
...@@ -4765,7 +5007,7 @@ class Reshape(Op): ...@@ -4765,7 +5007,7 @@ class Reshape(Op):
', should be %i' % (len(shp), self.ndim), shp) ', should be %i' % (len(shp), self.ndim), shp)
try: try:
out[0] = numpy.reshape(x, shp) out[0] = numpy.reshape(x, shp)
except Exception, e: except Exception:
raise ValueError('Cannot reshape input of shape %s to shape %s' % raise ValueError('Cannot reshape input of shape %s to shape %s' %
(x.shape, shp)) (x.shape, shp))
......
...@@ -721,14 +721,14 @@ class ShapeFeature(object): ...@@ -721,14 +721,14 @@ class ShapeFeature(object):
def shape_ir(self, i, r): def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i""" """Return symbolic r.shape[i] for tensor variable r, int i"""
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] return Shape_i(i).make_node(r).outputs[0]
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r""" """Return a tuple of symbolic shape vars for tensor variable r"""
return tuple([self.shape_ir(i,r) for i in xrange(r.ndim)]) return tuple([self.shape_ir(i, r) for i in xrange(r.ndim)])
def default_infer_shape(self, node, i_shapes): def default_infer_shape(self, node, i_shapes):
"""Return a list of shape tuple or None for the outputs of node. """Return a list of shape tuple or None for the outputs of node.
...@@ -861,7 +861,7 @@ class ShapeFeature(object): ...@@ -861,7 +861,7 @@ class ShapeFeature(object):
if r not in self.shape_of: if r not in self.shape_of:
try: try:
self.set_shape(r, self.shape_tuple(r)) self.set_shape(r, self.shape_tuple(r))
except AttributeError: #XXX: where would this come from? except AttributeError: # XXX: where would this come from?
self.set_shape(r, None) self.set_shape(r, None)
def make_vector_shape(self, r): def make_vector_shape(self, r):
...@@ -949,17 +949,18 @@ class ShapeFeature(object): ...@@ -949,17 +949,18 @@ class ShapeFeature(object):
if sh is None: if sh is None:
continue continue
for i, d in enumerate(sh): for i, d in enumerate(sh):
# Note: we ignore any shape element that is not typed (i.e. does # Note: we ignore any shape element that is not typed (i.e.,
# not have a 'dtype' attribute). This means there may still # does not have a 'dtype' attribute). This means there may
# remain int elements that are int32 on 32-bit platforms, but # still remain int elements that are int32 on 32-bit platforms,
# this works with `local_useless_subtensor`, so for now we # but this works with `local_useless_subtensor`, so for now we
# keep it this way. See #266 for a better long-term fix. # keep it this way. See #266 for a better long-term fix.
if getattr(d, 'dtype', 'int64') != 'int64': if getattr(d, 'dtype', 'int64') != 'int64':
assert d.dtype in theano.tensor.int_dtypes assert d.dtype in theano.tensor.int_dtypes
new_shape += sh[len(new_shape):i + 1] new_shape += sh[len(new_shape):i + 1]
new_shape[i] = theano.tensor.cast(d, 'int64') new_shape[i] = theano.tensor.cast(d, 'int64')
if new_shape: if new_shape:
# We replace the shape with wrong dtype by the one with 'int64'. # We replace the shape with wrong dtype by the one with
# 'int64'.
new_shape += sh[len(new_shape):] new_shape += sh[len(new_shape):]
o_shapes[sh_idx] = tuple(new_shape) o_shapes[sh_idx] = tuple(new_shape)
new_shape = [] new_shape = []
...@@ -990,8 +991,8 @@ class ShapeFeature(object): ...@@ -990,8 +991,8 @@ class ShapeFeature(object):
for (shpnode, idx) in (r.clients + [(node, i)]): for (shpnode, idx) in (r.clients + [(node, i)]):
if isinstance(getattr(shpnode, 'op', None), Shape_i): if isinstance(getattr(shpnode, 'op', None), Shape_i):
self.scheduled[shpnode] = new_r self.scheduled[shpnode] = new_r
# In case 2, if r is a variable that we've scheduled for shape update, then we # In case 2, if r is a variable that we've scheduled for shape update,
# should cancel it. # then we should cancel it.
unscheduled = [k for k, v in self.scheduled.items() if v == r] unscheduled = [k for k, v in self.scheduled.items() if v == r]
for k in unscheduled: for k in unscheduled:
del self.scheduled[k] del self.scheduled[k]
...@@ -1212,9 +1213,10 @@ def local_alloc_unary(node): ...@@ -1212,9 +1213,10 @@ def local_alloc_unary(node):
class Assert(T.Op): class Assert(T.Op):
""" """
Implements assertion in a computational graph. Implements assertion in a computational graph.
Notes: Notes:
This Op can be removed from the graph because of optimizations, and can hide This Op can be removed from the graph because of optimizations, and can
some possible optimizations to the optimizer. hide some possible optimizations to the optimizer.
Also, the output of the Op must be returned by the function computing the Also, the output of the Op must be returned by the function computing the
graph, otherwise it will not be used. graph, otherwise it will not be used.
""" """
...@@ -2773,7 +2775,6 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2773,7 +2775,6 @@ class Canonizer(gof.LocalOptimizer):
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
inputs = node.inputs
out = node.outputs[0] out = node.outputs[0]
assert len(node.outputs) == 1 assert len(node.outputs) == 1
...@@ -2934,7 +2935,6 @@ def local_sum_div_dimshuffle(node): ...@@ -2934,7 +2935,6 @@ def local_sum_div_dimshuffle(node):
axis = range(node.inputs[0].ndim) axis = range(node.inputs[0].ndim)
#print 'axis =', axis #print 'axis =', axis
thing_summed = node.inputs[0] thing_summed = node.inputs[0]
dimshuffled = None
if thing_summed.owner and thing_summed.owner.op == T.true_div: if thing_summed.owner and thing_summed.owner.op == T.true_div:
numerator, denominator = thing_summed.owner.inputs numerator, denominator = thing_summed.owner.inputs
...@@ -3035,11 +3035,13 @@ def local_sum_sum(node): ...@@ -3035,11 +3035,13 @@ def local_sum_sum(node):
if summed.owner.op.axis is None: if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce # special case of local_cut_useless_reduce
return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(
summed.owner.inputs[0])]
if node.op.axis is None: if node.op.axis is None:
# we're summing up everything anyway so lets # we're summing up everything anyway so lets
# do it all at once # do it all at once
return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(
summed.owner.inputs[0])]
newaxis = list(tuple(summed.owner.op.axis)) newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input # figure out which dimensions of the original input
...@@ -3113,7 +3115,7 @@ def local_sum_alloc(node): ...@@ -3113,7 +3115,7 @@ def local_sum_alloc(node):
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes) val = val.reshape(1)[0] * T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)] return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError, e: except TypeError:
pass pass
else: else:
try: try:
...@@ -3127,7 +3129,7 @@ def local_sum_alloc(node): ...@@ -3127,7 +3129,7 @@ def local_sum_alloc(node):
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype), return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])] if i not in node.op.axis])]
except TypeError, e: except TypeError:
pass pass
...@@ -4433,7 +4435,6 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4433,7 +4435,6 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
fusion optimization. We skip this optimization. You can ignore this message, fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower.""") your code will run correctly, but may be slower.""")
otype = node.outputs[0].type
s_new_out = node.op.scalar_op(*s_g) s_new_out = node.op.scalar_op(*s_g)
try: try:
s_new_out.owner.op.c_code(s_new_out.owner, s_new_out.owner.op.c_code(s_new_out.owner,
...@@ -4509,7 +4510,7 @@ class FusionOptimizer(Optimizer): ...@@ -4509,7 +4510,7 @@ class FusionOptimizer(Optimizer):
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
reason=self.__class__.__name__) reason=self.__class__.__name__)
did_something = True did_something = True
except InconsistencyError, e: except InconsistencyError:
pass pass
if config.tensor.local_elemwise_fusion: if config.tensor.local_elemwise_fusion:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论