Backporting code to work with Python 2.4

Summary of changes: - Added implementations of functools.partial, defaultdict, any, all - Re-wrote relative imports to be absolute imports - Re-wrote conditional expressions to normal if/else expressions
上级 fc3d3432
...@@ -24,6 +24,7 @@ To learn more, check out: ...@@ -24,6 +24,7 @@ To learn more, check out:
""" """
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import gof
from gof import \ from gof import \
CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \ CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \
......
from .. import gof from theano import gof
from .. import gradient as G from theano import gradient as G
from function_module import function from function_module import function
......
...@@ -6,19 +6,19 @@ from StringIO import StringIO ...@@ -6,19 +6,19 @@ from StringIO import StringIO
import numpy import numpy
from .. import gof from theano import gof
from ..gof import Env, graph, utils, link from theano.gof import Env, graph, utils, link
from ..gof.link import WrapLinkerMany, raise_with_op from theano.gof.link import WrapLinkerMany, raise_with_op
from ..gof.cutils import run_cthunk #from theano.gof.cutils import run_cthunk
from ..gof.cc import OpWiseCLinker, CLinker from theano.gof.cc import OpWiseCLinker, CLinker
from ..compile.function_module import (FunctionMaker, from theano.compile.function_module import (FunctionMaker,
Function, Function,
infer_reuse_pattern, infer_reuse_pattern,
SymbolicInput, SymbolicInput,
SymbolicInputKit, SymbolicInputKit,
SymbolicOutput, SymbolicOutput,
Supervisor) Supervisor)
from ..compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
######################## ########################
...@@ -221,7 +221,10 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout): ...@@ -221,7 +221,10 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout):
""" """
if depth==0: if depth==0:
return return
done = set() if done is None else done #backport
if done is None:
done = set()
#done = set() if done is None else done
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
# this variable is the output of computation, # this variable is the output of computation,
# so just print out the apply # so just print out the apply
...@@ -509,9 +512,15 @@ def _find_bad_optimizations2(order, reasons, r_vals): ...@@ -509,9 +512,15 @@ def _find_bad_optimizations2(order, reasons, r_vals):
checked_variables.add(r) checked_variables.add(r)
# (recursively) first check all the variables that could make r look bad: # (recursively) first check all the variables that could make r look bad:
list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]]
if (None is not r.owner):
list_of_vars += r.owner.inputs
for var_that_could_make_r_look_bad in \ for var_that_could_make_r_look_bad in \
[old_r for (reason, old_r, olds, news) in reasons[r]] \ list_of_vars:
+ ([] if (None is r.owner) else r.owner.inputs): #backport
#[old_r for (reason, old_r, olds, news) in reasons[r]] \
#+ ([] if (None is r.owner) else r.owner.inputs):
check_variable(var_that_could_make_r_look_bad) check_variable(var_that_could_make_r_look_bad)
check_variable_norec(r) check_variable_norec(r)
...@@ -559,11 +568,18 @@ class _EnvEvent(object): ...@@ -559,11 +568,18 @@ class _EnvEvent(object):
def __str__(self): def __str__(self):
if self.kind == 'change': if self.kind == 'change':
if (self.op != 'output'):
msg = str(len(self.node.inputs))
else:
msg = ''
return ' '.join(['change', return ' '.join(['change',
self.reason, self.reason,
str(self.op), str(self.op),
str(self.idx), str(self.idx),
str(len(self.node.inputs)) if (self.op != 'output') else '']) msg])
#backport
#str(len(self.node.inputs)) if (self.op != 'output') else ''])
else: else:
return str(self.__dict__) return str(self.__dict__)
...@@ -1015,7 +1031,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1015,7 +1031,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
# WARNING: this is a global mechanism... so it will screw up if we are trying to use # WARNING: this is a global mechanism... so it will screw up if we are trying to use
# multiple modes at once. # multiple modes at once.
from ..tensor import TensorType #to set filter_check_isfinite from theano.tensor import TensorType #to set filter_check_isfinite
TensorType.filter_checks_isfinite = mode.check_isfinite TensorType.filter_checks_isfinite = mode.check_isfinite
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
...@@ -1053,8 +1069,19 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1053,8 +1069,19 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
for j in xrange(max(len(li), len(l0))): for j in xrange(max(len(li), len(l0))):
if j >= len(li) or j >= len(l0) or li[j] != l0[j]: if j >= len(li) or j >= len(l0) or li[j] != l0[j]:
print >> infolog, "* ", j, print >> infolog, "* ", j,
print >> infolog, " ", str(li[j]) if j < len(li) else '-', if j < len(li):
print >> infolog, " ", str(l0[j]) if j < len(l0) else '-' msg = str(li[j])
else:
msg = '-'
print >> infolog, " ", msg
if j < len(l0):
msg = str(l0[j])
else:
msg = '-'
print >> infolog, " ", msg
#backport
#print >> infolog, " ", str(li[j]) if j < len(li) else '-',
#print >> infolog, " ", str(l0[j]) if j < len(l0) else '-'
else: else:
pass pass
raise StochasticOrder(infolog.getvalue()) raise StochasticOrder(infolog.getvalue())
......
...@@ -6,11 +6,14 @@ __docformat__ = "restructuredtext en" ...@@ -6,11 +6,14 @@ __docformat__ = "restructuredtext en"
import copy_reg import copy_reg
import cPickle import cPickle
from functools import partial import sys
if sys.version_info[:2] >= (2,5):
from functools import partial
import numpy import numpy
from .. import gof import theano.gof
import sys #from theano import gof
import copy import copy
import mode as mode_module import mode as mode_module
...@@ -33,8 +36,18 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -33,8 +36,18 @@ def infer_reuse_pattern(env, outputs_to_disown):
do_not_reuse.append(r) do_not_reuse.append(r)
node = r.owner node = r.owner
op = node.op op = node.op
dmap = op.destroy_map if hasattr(op, 'destroy_map') else {} if hasattr(op, 'destroy_map'):
vmap = op.view_map if hasattr(op, 'view_map') else {} dmap = op.destroy_map
else:
dmap = {}
if hasattr(op, 'view_map'):
vmap = op.view_map
else:
vmap = {}
#backport
#dmap = op.destroy_map if hasattr(op, 'destroy_map') else {}
#vmap = op.view_map if hasattr(op, 'view_map') else {}
for l in dmap.values() + vmap.values(): for l in dmap.values() + vmap.values():
for i in l: for i in l:
walk(node.inputs[i]) walk(node.inputs[i])
...@@ -252,7 +265,12 @@ class Function(object): ...@@ -252,7 +265,12 @@ class Function(object):
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
finder[i] = c finder[i] = c
finder[input.variable] = c finder[input.variable] = c
finder[input.name] = c if input.name not in finder else DUPLICATE if input.name not in finder:
finder[input.name] = c
else:
finder[input.name] = DUPLICATE
#backport
#finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message) # inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input inv_finder[c] = input
#setters.append(partial(assign, c)) #setters.append(partial(assign, c))
...@@ -271,7 +289,12 @@ class Function(object): ...@@ -271,7 +289,12 @@ class Function(object):
# can reinitialize all the containers # can reinitialize all the containers
finder[i] = f finder[i] = f
finder[input] = f finder[input] = f
finder[input.name] = f if input.name not in finder else DUPLICATE if input.name not in finder:
finder[input.name] = f
else:
finder[input.name] = DUPLICATE
#backport
#finder[input.name] = f if input.name not in finder else DUPLICATE
#setters.append(f) #setters.append(f)
# For each input in the kit and its corresponding container, we put an entry in finder. # For each input in the kit and its corresponding container, we put an entry in finder.
# This allows the user to micro-manage elements of the kit if need be. # This allows the user to micro-manage elements of the kit if need be.
...@@ -279,7 +302,12 @@ class Function(object): ...@@ -279,7 +302,12 @@ class Function(object):
for c, sin in zip(cs, sinputs): for c, sin in zip(cs, sinputs):
finder[sin.variable] = c finder[sin.variable] = c
finder[sin.name] = c finder[sin.name] = c
finder[sin.name] = c if sin.name not in finder else DUPLICATE if sin.name not in finder:
finder[sin.name] = c
else:
finder[sin.name] = DUPLICATE
#backport
#finder[sin.name] = c if sin.name not in finder else DUPLICATE
inv_finder[c] = input inv_finder[c] = input
c.required = required c.required = required
c.provided = 0 c.provided = 0
...@@ -507,9 +535,15 @@ class SanityCheckFunction(Function): ...@@ -507,9 +535,15 @@ class SanityCheckFunction(Function):
if not input.mutable: if not input.mutable:
if not self.check_equal(c1.value, c2.value): if not self.check_equal(c1.value, c2.value):
name = c2.name name = c2.name
if name:
the_name = name
else:
the_name = ""
raise ValueError("Input #%i%s using %s and %s differs." raise ValueError("Input #%i%s using %s and %s differs."
% (i, % (i,
" (%s)" % name if name else "", #backport
#" (%s)" % name if name else "",
" (%s)" % the_name,
self.maker.mode, self.maker.mode,
fn.maker.mode), fn.maker.mode),
c1.value, c2.value) c1.value, c2.value)
...@@ -520,9 +554,15 @@ class SanityCheckFunction(Function): ...@@ -520,9 +554,15 @@ class SanityCheckFunction(Function):
r2 = c2.value r2 = c2.value
if not self.check_equal(r1, r2): if not self.check_equal(r1, r2):
name = c2.name name = c2.name
if name:
the_name = name
else:
the_name = ""
raise ValueError("Variable #%i%s using %s and %s differs." raise ValueError("Variable #%i%s using %s and %s differs."
% (i, % (i,
" (%s)" % name if name else "", #backport
#" (%s)" % name if name else "",
" (%s)" % the_name,
self.maker.mode, self.maker.mode,
fn.maker.mode), fn.maker.mode),
r1, r2) r1, r2)
...@@ -598,7 +638,10 @@ class FunctionMaker(object): ...@@ -598,7 +638,10 @@ class FunctionMaker(object):
in the graph from the inputs to the outputs in the graph from the inputs to the outputs
""" """
mode = mode if mode is not None else mode_module.default_mode if mode is None:
mode = mode_module.default_mode
#backport
#mode = mode if mode is not None else mode_module.default_mode
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False unpack_single = False
...@@ -703,7 +746,16 @@ class FunctionMaker(object): ...@@ -703,7 +746,16 @@ class FunctionMaker(object):
def _pickle_FunctionMaker(fm): def _pickle_FunctionMaker(fm):
outputs = None if fm.return_none else (fm.outputs[0] if fm.unpack_single else fm.outputs) if fm.return_none:
outputs = None
else:
if fm.unpack_single:
outputs = fm.outputs[0]
else:
outputs = fm.outputs
#backport
#outputs = None if fm.return_none else (fm.outputs[0] if fm.unpack_single else fm.outputs)
rval = (_constructor_FunctionMaker, (fm.inputs, outputs, fm.mode, fm.accept_inplace)) rval = (_constructor_FunctionMaker, (fm.inputs, outputs, fm.mode, fm.accept_inplace))
return rval return rval
...@@ -788,12 +840,19 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -788,12 +840,19 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
f[<kitname>] = seed #re-seed the elements of a RandomKit f[<kitname>] = seed #re-seed the elements of a RandomKit
""" """
mode = mode if mode is not None else mode_module.default_mode if mode is None:
mode = mode_module.default_mode
#backport
#mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs) if isinstance(outputs, (list, tuple)):
outputs = map(FunctionMaker.wrap_out, outputs)
else:
outputs = FunctionMaker.wrap_out(outputs)
#backport
#outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs] defaults = [getattr(input, 'value', None) for input in inputs]
...@@ -807,9 +866,17 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -807,9 +866,17 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
# TODO This may need to be changed to use containers as defaults. # TODO This may need to be changed to use containers as defaults.
return [copy.copy(default.value) if isinstance(default, gof.Container) else retval = []
copy.copy(default) for default in defaults:
for default in defaults] if isinstance(default, gof.Container):
retval +=[copy.copy(default.value)]
else:
retval +=[copy.copy(default)]
return retval
#backport
#return [copy.copy(default.value) if isinstance(default, gof.Container) else
# copy.copy(default)
# for default in defaults]
makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]] makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]]
fns = [maker.create(dup_defaults(), trustme = True) for maker in makers] fns = [maker.create(dup_defaults(), trustme = True) for maker in makers]
builder = partial(SanityCheckFunction, fns, check_equal) builder = partial(SanityCheckFunction, fns, check_equal)
......
...@@ -39,11 +39,23 @@ class SymbolicInput(object): ...@@ -39,11 +39,23 @@ class SymbolicInput(object):
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
self.name = variable.name if (autoname and name is None) else name if (autoname and name is None):
self.name = variable.name
else:
self.name = name
#backport
#self.name = variable.name if (autoname and name is None) else name
if self.name is not None and not isinstance(self.name, str): if self.name is not None and not isinstance(self.name, str):
raise TypeError("name must be a string! (got: %s)" % self.name) raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update self.update = update
self.mutable = mutable if (mutable is not None) else (update is not None) if (mutable is not None):
self.mutable = mutable
else:
self.mutable = (update is not None)
#backport
#self.mutable = mutable if (mutable is not None) else (update is not None)
self.strict = strict self.strict = strict
self.implicit = implicit self.implicit = implicit
......
...@@ -8,10 +8,14 @@ __docformat__ = "restructuredtext en" ...@@ -8,10 +8,14 @@ __docformat__ = "restructuredtext en"
from theano import gof from theano import gof
from theano.printing import pprint from theano.printing import pprint
from collections import defaultdict
from itertools import chain
from functools import partial
import io, sys import io, sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
from itertools import chain
if sys.version_info[:2] >= (2,5):
from functools import partial
import function_module as F import function_module as F
import mode as get_mode import mode as get_mode
...@@ -418,8 +422,26 @@ class Method(Component): ...@@ -418,8 +422,26 @@ class Method(Component):
outputs = self.outputs outputs = self.outputs
_inputs = [x.variable for x in inputs] _inputs = [x.variable for x in inputs]
# Grab the variables that are not accessible from either the inputs or the updates. # Grab the variables that are not accessible from either the inputs or the updates.
outputs_list = [] if outputs is None else (list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]) if outputs is None:
outputs_variable_list = [o.variable if isinstance(o, io.Out) else o for o in outputs_list] outputs_list = []
else:
if isinstance(outputs, (list, tuple)):
outputs_list = list(outputs)
else:
outputs_list = [outputs]
#backport
#outputs_list = [] if outputs is None else (list(outputs) if isinstance(outputs, (list, tuple)) else [outputs])
outputs_variable_list = []
for o in outputs_list:
if isinstance(o, io.Out):
outputs_variable_list += [o.variable]
else:
outputs_variable_list += [o]
#backport
#outputs_variable_list = [o.variable if isinstance(o, io.Out) else o for o in outputs_list]
for input in gof.graph.inputs(outputs_variable_list for input in gof.graph.inputs(outputs_variable_list
+ [x.update for x in inputs if getattr(x, 'update', False)], + [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs): blockers = _inputs):
...@@ -448,7 +470,13 @@ class Method(Component): ...@@ -448,7 +470,13 @@ class Method(Component):
assert type(storage) is io.In assert type(storage) is io.In
inputs.append(storage) inputs.append(storage)
effective_mode = mode if self.mode is None else self.mode if self.mode is None:
effective_mode = mode
else:
effective_mode = self.mode
#backport
#effective_mode = mode if self.mode is None else self.mode
rval = F.function(inputs, outputs, effective_mode) rval = F.function(inputs, outputs, effective_mode)
memo[self] = rval memo[self] = rval
return rval return rval
...@@ -459,7 +487,13 @@ class Method(Component): ...@@ -459,7 +487,13 @@ class Method(Component):
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs)) rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else: else:
rval = '' rval = ''
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates if isinstance(self.outputs, (list, tuple)):
inputs, outputs, updates = self.inputs, self.outputs
else:
inputs, outputs, updates = [self.outputs], self.updates
#backport
#inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
# If mode is in kwargs, prints the optimized version of the method # If mode is in kwargs, prints the optimized version of the method
mode = kwargs.pop('mode', None) mode = kwargs.pop('mode', None)
...@@ -472,10 +506,16 @@ class Method(Component): ...@@ -472,10 +506,16 @@ class Method(Component):
return rval return rval
def __str__(self): def __str__(self):
if self.updates:
sep = "; "
else:
sep = ""
return "Method(%s -> %s%s%s)" % \ return "Method(%s -> %s%s%s)" % \
(self.inputs, (self.inputs,
self.outputs, self.outputs,
"; " if self.updates else "", sep,
#backport
#"; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems())) ", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
...@@ -603,7 +643,16 @@ class Composite(Component): ...@@ -603,7 +643,16 @@ class Composite(Component):
self.set(item, value) self.set(item, value)
def __iter__(self): def __iter__(self):
return (c.r if isinstance(c, (External, Member)) else c for c in self.components()) retval = []
for c in self.components():
if isinstance(c, (External, Member)):
retval += [c.r]
else:
retval += [c]
return retval
#backport
#return (c.r if isinstance(c, (External, Member)) else c for c in self.components())
...@@ -1047,7 +1096,12 @@ class Module(ComponentDict): ...@@ -1047,7 +1096,12 @@ class Module(ComponentDict):
# Function to go through member lists and dictionaries recursively, # Function to go through member lists and dictionaries recursively,
# to look for submodules on which make_module_instance needs to be called # to look for submodules on which make_module_instance needs to be called
def recurse(v): def recurse(v):
iter = enumerate(v) if isinstance(v,list) else v.iteritems() if isinstance(v,list):
iter = enumerate(v)
else:
iter = v.iteritems()
#backport
#iter = enumerate(v) if isinstance(v,list) else v.iteritems()
for sk,sv in iter: for sk,sv in iter:
if isinstance(sv,(list,dict)): if isinstance(sv,(list,dict)):
sv = recurse(sv) sv = recurse(sv)
......
import time, atexit import time, atexit
from ..gof.link import WrapLinkerMany from theano.gof.link import WrapLinkerMany
from ..gof.cutils import run_cthunk from theano.gof.cutils import run_cthunk
from ..compile.mode import Mode, predefined_linkers, register_mode, predefined_modes from theano.compile.mode import Mode, predefined_linkers, register_mode, predefined_modes
from ..gof.cc import OpWiseCLinker from theano.gof.cc import OpWiseCLinker
class ProfileMode(Mode): class ProfileMode(Mode):
def __init__(self, linker=OpWiseCLinker(), optimizer=None): def __init__(self, linker=OpWiseCLinker(), optimizer=None):
...@@ -82,7 +82,13 @@ class ProfileMode(Mode): ...@@ -82,7 +82,13 @@ class ProfileMode(Mode):
tot=0 tot=0
for f,t,a,ci in otimes[:n_ops_to_print]: for f,t,a,ci in otimes[:n_ops_to_print]:
tot+=t tot+=t
print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a) if ci:
msg = '*'
else:
msg = ' '
print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a)
#backport
#print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a)
print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\ print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\
%(max(0, len(otimes)-n_ops_to_print), %(max(0, len(otimes)-n_ops_to_print),
sum(f for f, t, a, ci in otimes[n_ops_to_print:])*100, sum(f for f, t, a, ci in otimes[n_ops_to_print:])*100,
...@@ -104,7 +110,13 @@ class ProfileMode(Mode): ...@@ -104,7 +110,13 @@ class ProfileMode(Mode):
tot=0 tot=0
for f,t,a,ci in sotimes[:n_ops_to_print]: for f,t,a,ci in sotimes[:n_ops_to_print]:
tot+=t tot+=t
print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a) if ci:
msg = '*'
else:
msg = ' '
print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a)
#backport
#print ' %.2f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a)
print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\ print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\
%(max(0, len(sotimes)-n_ops_to_print), %(max(0, len(sotimes)-n_ops_to_print),
sum(f for f, t, a in sotimes[n_ops_to_print:])*100, sum(f for f, t, a in sotimes[n_ops_to_print:])*100,
......
...@@ -29,13 +29,21 @@ class CallCache(object): ...@@ -29,13 +29,21 @@ class CallCache(object):
self.cache = {} self.cache = {}
def persist(self, filename=None): def persist(self, filename=None):
filename = self.filename if filename is None else filename if filename is None:
filename = self.filename
#backport
#filename = self.filename if filename is None else filename
f = file(filename, 'w') f = file(filename, 'w')
cPickle.dump(self.cache, f) cPickle.dump(self.cache, f)
f.close() f.close()
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
key = (fn, tuple(args)) if key is None else key if key is None:
key = (fn, tuple(args))
#backport
#key = (fn, tuple(args)) if key is None else key
if key not in self.cache: if key not in self.cache:
debug('cache miss', len(self.cache)) debug('cache miss', len(self.cache))
self.cache[key] = fn(*args) self.cache[key] = fn(*args)
......
...@@ -5,7 +5,10 @@ Defines Linkers that deal with C implementations. ...@@ -5,7 +5,10 @@ Defines Linkers that deal with C implementations.
# Python imports # Python imports
from copy import copy from copy import copy
import re #for set_compiledir import re #for set_compiledir
import os, sys, platform, StringIO, time, hashlib import os, sys, platform, StringIO, time
import md5
if sys.version_info[:2] >= (2,5):
import hashlib
# weave import # weave import
from scipy import weave from scipy import weave
...@@ -37,7 +40,7 @@ def error(*args): ...@@ -37,7 +40,7 @@ def error(*args):
sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n') sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
_logger.error(' '.join(str(a) for a in args)) _logger.error(' '.join(str(a) for a in args))
from .callcache import CallCache from theano.gof.callcache import CallCache
def get_module_cache(): def get_module_cache():
return cmodule.get_module_cache(get_compiledir()) return cmodule.get_module_cache(get_compiledir())
...@@ -514,7 +517,12 @@ class CLinker(link.Linker): ...@@ -514,7 +517,12 @@ class CLinker(link.Linker):
# The hash calculated on the code identifies it so weave can cache properly. # The hash calculated on the code identifies it so weave can cache properly.
# (the hash has to be used outside of the support code because weave does not consider changes in the support code) # (the hash has to be used outside of the support code because weave does not consider changes in the support code)
hash = hashlib.md5(struct_code).hexdigest() # hashlib is new to 2.5
if sys.version_info[:2] < (2,5):
hash = md5.new(struct_code).hexdigest()
else:
hash = hashlib.md5(struct_code).hexdigest()
struct_name = '__struct_compiled_op_%s' % hash struct_name = '__struct_compiled_op_%s' % hash
#struct_code %= dict(name = struct_name) #struct_code %= dict(name = struct_name)
struct_code = re.sub("<<<<NAME>>>>", struct_name, struct_code) struct_code = re.sub("<<<<NAME>>>>", struct_name, struct_code)
...@@ -769,7 +777,10 @@ class CLinker(link.Linker): ...@@ -769,7 +777,10 @@ class CLinker(link.Linker):
""" """
This method is a callback for `ModuleCache.module_from_key` This method is a callback for `ModuleCache.module_from_key`
""" """
location = get_compiledir() if location is None else location if location is None:
location = get_compiledir()
#backport
#location = get_compiledir() if location is None else location
mod = self.build_dynamic_module() mod = self.build_dynamic_module()
get_lock() get_lock()
try: try:
......
...@@ -237,7 +237,12 @@ class ModuleCache(object): ...@@ -237,7 +237,12 @@ class ModuleCache(object):
self.module_from_name = dict(self.module_from_name) self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key) self.entry_from_key = dict(self.entry_from_key)
self.stats = [0, 0, 0] self.stats = [0, 0, 0]
self.force_fresh = self.force_fresh if force_fresh is None else force_fresh if force_fresh is None:
self.force_fresh = self.force_fresh
else:
self.force_fresh = force_fresh
#backport
#self.force_fresh = self.force_fresh if force_fresh is None else force_fresh
self.loaded_key_pkl = set() self.loaded_key_pkl = set()
self.refresh() self.refresh()
...@@ -393,7 +398,11 @@ class ModuleCache(object): ...@@ -393,7 +398,11 @@ class ModuleCache(object):
:param age_thresh: dynamic modules whose last access time is more than ``age_thresh`` :param age_thresh: dynamic modules whose last access time is more than ``age_thresh``
seconds ago will be erased. seconds ago will be erased.
""" """
age_thresh = self.age_thresh if age_thresh is None else age_thresh if age_thresh is None:
age_thresh = self.age_thresh
#backport
#age_thresh = self.age_thresh if age_thresh is None else age_thresh
compilelock.get_lock() compilelock.get_lock()
try: try:
# update the age of modules that have been accessed by other processes # update the age of modules that have been accessed by other processes
...@@ -474,7 +483,14 @@ def get_gcc_shared_library_arg(): ...@@ -474,7 +483,14 @@ def get_gcc_shared_library_arg():
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[], tmpdir=None): preargs=[], tmpdir=None):
#TODO: don't to the dlimport in this function #TODO: don't to the dlimport in this function
preargs= [] if preargs is None else list(preargs)
if preargs is None:
preargs = []
else:
preargs = list(preargs)
#backport
#preargs= [] if preargs is None else list(preargs)
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
...@@ -537,7 +553,13 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -537,7 +553,13 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[], tmpdir=None): preargs=[], tmpdir=None):
preargs= [] if preargs is None else list(preargs) if preargs is None:
preargs = []
else:
preargs = list(preargs)
#backport
#preargs= [] if preargs is None else list(preargs)
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
......
"""WRITEME""" """WRITEME"""
from collections import defaultdict import sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
# otherwise it's implemented in python25.py
import toolbox import toolbox
import graph import graph
......
...@@ -138,7 +138,11 @@ class Container(object): ...@@ -138,7 +138,11 @@ class Container(object):
self.type = r self.type = r
else: else:
self.type = r.type self.type = r.type
self.name = r.name if name is None else name if name is None:
self.name = r.name
#backport
#self.name = r.name if name is None else name
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict self.strict = strict
......
...@@ -4,6 +4,7 @@ amount of useful generic optimization tools. ...@@ -4,6 +4,7 @@ amount of useful generic optimization tools.
""" """
import sys
import graph import graph
from env import InconsistencyError from env import InconsistencyError
import utils import utils
...@@ -11,9 +12,13 @@ import unify ...@@ -11,9 +12,13 @@ import unify
import toolbox import toolbox
import op import op
from copy import copy from copy import copy
from collections import deque, defaultdict from theano.gof.python25 import any, all
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
from collections import deque
import destroyhandler as dh import destroyhandler as dh
import sys
import traceback import traceback
_optimizer_idx = [0] _optimizer_idx = [0]
...@@ -900,7 +905,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -900,7 +905,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = True max_use_abort = True
else: else:
lopt_change = self.process_node(env, node, lopt) lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0 if lopt_change:
process_count[lopt] += 1
else:
process_count[lopt] += 0
#backport
#process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change changed |= lopt_change
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
......
import sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
else:
from python25 import defaultdict
from collections import defaultdict
import opt import opt
......
...@@ -23,6 +23,42 @@ if sys.version_info[:2] < (2,5): ...@@ -23,6 +23,42 @@ if sys.version_info[:2] < (2,5):
newfunc.args = args newfunc.args = args
newfunc.keywords = keywords newfunc.keywords = keywords
return newfunc return newfunc
class defaultdict(dict):
def __init__(self, default_factory=None, *a, **kw):
if (default_factory is not None and
not hasattr(default_factory, '__call__')):
raise TypeError('first argument must be callable')
dict.__init__(self, *a, **kw)
self.default_factory = default_factory
def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self.__missing__(key)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = value = self.default_factory()
return value
def __reduce__(self):
if self.default_factory is None:
args = tuple()
else:
args = self.default_factory,
# consider replacing items() with iteritems()
return type(self), args, None, None, self.items()
def copy(self):
return self.__copy__()
def __copy__(self):
return type(self)(self.default_factory, self)
def __deepcopy__(self, memo):
import copy
return type(self)(self.default_factory,
copy.deepcopy(self.items()))
def __repr__(self):
return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self))
else: else:
# Only bother with this else clause and the __all__ line if you are putting # Only bother with this else clause and the __all__ line if you are putting
# this in a separate file. # this in a separate file.
......
import sys
if sys.version_info[:2] >= (2,5):
from functools import partial
from functools import partial
import graph import graph
import sys
class AlreadyThere(Exception): class AlreadyThere(Exception):
......
...@@ -32,7 +32,13 @@ class Print(Op): ...@@ -32,7 +32,13 @@ class Print(Op):
xout[0] = xin xout[0] = xin
for attr in self.attrs: for attr in self.attrs:
temp = getattr(xin, attr) temp = getattr(xin, attr)
print self.message, attr,'=', temp() if callable(temp) else temp if callable(temp):
pmsg = temp()
else:
psmg = temp
print self.message, attr,'=', pmsg
#backport
#print self.message, attr,'=', temp() if callable(temp) else temp
def grad(self,input,output_gradients): def grad(self,input,output_gradients):
return output_gradients return output_gradients
...@@ -233,7 +239,12 @@ class PPrinter: ...@@ -233,7 +239,12 @@ class PPrinter:
strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output)))) strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output))))
i += 1 i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
name = 'out[%i]' % outputs.index(output) if output.name is None else output.name if output.name is None:
name = 'out[%i]' % outputs.index(output)
else:
name = output.name
#backport
#name = 'out[%i]' % outputs.index(output) if output.name is None else output.name
current = output current = output
try: try:
idx = 2000 + outputs.index(output) idx = 2000 + outputs.index(output)
......
差异被折叠。
...@@ -12,10 +12,10 @@ from scipy import sparse ...@@ -12,10 +12,10 @@ from scipy import sparse
import scipy.sparse import scipy.sparse
from theano.printing import Print from theano.printing import Print
from .. import gof from theano import gof
from .. import tensor from theano import tensor
from .. import compile from theano import compile
from .. import scalar from theano import scalar
#TODO: move this decorator to the compile submodule #TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
...@@ -273,7 +273,12 @@ class CSMProperties(gof.Op): ...@@ -273,7 +273,12 @@ class CSMProperties(gof.Op):
[data, tensor.ivector(), tensor.ivector(), tensor.ivector()]) [data, tensor.ivector(), tensor.ivector(), tensor.ivector()])
def perform(self, node, (csm,), out): def perform(self, node, (csm,), out):
out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap] if self.kmap is None:
out[0][0] = csm.data
else:
out[0][0] = csm.data[self.kmap]
#backport
#out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap]
out[1][0] = numpy.asarray(csm.indices, dtype='int32') out[1][0] = numpy.asarray(csm.indices, dtype='int32')
out[2][0] = numpy.asarray(csm.indptr, dtype='int32') out[2][0] = numpy.asarray(csm.indptr, dtype='int32')
out[3][0] = numpy.asarray(csm.shape, dtype='int32') out[3][0] = numpy.asarray(csm.shape, dtype='int32')
...@@ -1082,8 +1087,19 @@ register_specialize(local_structured_dot) ...@@ -1082,8 +1087,19 @@ register_specialize(local_structured_dot)
def structured_dot_grad(sparse_A, dense_B, ga): def structured_dot_grad(sparse_A, dense_B, ga):
if sparse_A.type.format in ('csc','csr'): if sparse_A.type.format in ('csc','csr'):
sdgcsx = sdg_csc if sparse_A.type.format == 'csc' else sdg_csr if sparse_A.type.format == 'csc':
CSx = CSC if sparse_A.type.format == 'csc' else CSR sdgcsx = sdg_csc
else:
sdgcsx = sdg_csr
#backport
#sdgcsx = sdg_csc if sparse_A.type.format == 'csc' else sdg_csr
if sparse_A.type.format == 'csc':
CSx = CSC
else:
CSx = CSR
#backport
#CSx = CSC if sparse_A.type.format == 'csc' else CSR
g_A_data = sdgcsx(csm_indices(sparse_A),\ g_A_data = sdgcsx(csm_indices(sparse_A),\
csm_indptr(sparse_A), dense_B, ga) csm_indptr(sparse_A), dense_B, ga)
......
差异被折叠。
...@@ -3,19 +3,19 @@ ...@@ -3,19 +3,19 @@
import os, sys, traceback import os, sys, traceback
import numpy import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError) InconsistencyError)
from ..printing import pprint, FunctionPrinter from theano.printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer from theano.tensor.opt import register_specialize, out2in, insert_inplace_optimizer
# opt.py # opt.py
import basic as T import basic as T
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
from .. import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text from theano.tensor.blas_headers import cblas_header_text, blas_header_text
@utils.memoize @utils.memoize
def ldflags(libs=True, flags=False): def ldflags(libs=True, flags=False):
...@@ -365,9 +365,16 @@ gemm = Gemm() ...@@ -365,9 +365,16 @@ gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm')) pprint.assign(gemm, FunctionPrinter('gemm'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
return node.owner \ if maxclients is not None:
retval = (len(node.clients) <= maxclients)
else:
retval = True
return node.owner \
and node.owner.op == op \ and node.owner.op == op \
and (len(node.clients) <= maxclients if maxclients is not None else True) and retval
#backport
# and (len(node.clients) <= maxclients if maxclients is not None else True)
class GemmLocalOptimizer(LocalOptimizer): class GemmLocalOptimizer(LocalOptimizer):
"""This is a massive beast for recognizing all the ways that a subtraction or addition """This is a massive beast for recognizing all the ways that a subtraction or addition
......
...@@ -2,13 +2,13 @@ import sys ...@@ -2,13 +2,13 @@ import sys
import elemwise_cgen as cgen import elemwise_cgen as cgen
import numpy import numpy
from .. import gof from theano import gof
from ..gof import Op, Apply from theano.gof import Op, Apply
from .. import scalar from theano import scalar
from ..scalar import Scalar from theano.scalar import Scalar
from .. import printing from theano import printing
from ..printing import pprint from theano.printing import pprint
from ..gof.python25 import all from theano.gof.python25 import all
from copy import copy, deepcopy from copy import copy, deepcopy
...@@ -215,19 +215,31 @@ class DimShuffle(Op): ...@@ -215,19 +215,31 @@ class DimShuffle(Op):
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')] '0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
shape_statements = ['npy_intp dimensions[%i]'%nd_out] shape_statements = ['npy_intp dimensions[%i]'%nd_out]
shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']') for i, o in enumerate(self.new_order):
if o != 'x' else if o != 'x':
('dimensions['+str(i)+'] = 1') shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']')]
for i, o in enumerate(self.new_order)] else:
shape_statements += [('dimensions['+str(i)+'] = 1')]
#backport
#shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']')
# if o != 'x' else
# ('dimensions['+str(i)+'] = 1')
# for i, o in enumerate(self.new_order)]
strides_statements = ['npy_intp strides[%i]'%nd_out] strides_statements = ['npy_intp strides[%i]'%nd_out]
#set the strides of the non-broadcasted dimensions #set the strides of the non-broadcasted dimensions
strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']') for i, o in enumerate(self.new_order):
if o != 'x' else if o != 'x':
('strides['+str(i)+'] = 0') strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']')]
for i, o in enumerate(self.new_order)] else:
strides_statements += [('strides['+str(i)+'] = 0')]
#backport
#strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']')
# if o != 'x' else
# ('strides['+str(i)+'] = 0')
# for i, o in enumerate(self.new_order)]
# set the strides of the broadcasted dimensions # set the strides of the broadcasted dimensions
# this algorithm is from numpy: PyArray_Newshape() in cvs/numpy/numpy/core/src/multiarraymodule.c # this algorithm is from numpy: PyArray_Newshape() in cvs/numpy/numpy/core/src/multiarraymodule.c
...@@ -442,7 +454,16 @@ class Elemwise(Op): ...@@ -442,7 +454,16 @@ class Elemwise(Op):
def _rehash(self): def _rehash(self):
items = self.inplace_pattern.items() items = self.inplace_pattern.items()
items.sort() items.sort()
tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items]) first_part = [k for k,v in items]
second_part = []
for k,v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
#backport
#tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items])
h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items) h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items)
assert h == getattr(self,'_hashval', h) assert h == getattr(self,'_hashval', h)
self._hashval = h self._hashval = h
...@@ -517,10 +538,21 @@ class Elemwise(Op): ...@@ -517,10 +538,21 @@ class Elemwise(Op):
for dims in zip(*[[(1, True)]*(maxsize - len(input.shape)) + zip(input.shape, sinput.type.broadcastable) for dims in zip(*[[(1, True)]*(maxsize - len(input.shape)) + zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)]): for input, sinput in zip(inputs, node.inputs)]):
if max(d for d,b in dims) != 1 and (1, False) in dims: if max(d for d,b in dims) != 1 and (1, False) in dims:
msg = []
for input, sinput in zip(inputs, node.inputs):
for d, b in zip(input.shape, sinput.type.broadcastable):
if b:
msg += ['*']
else:
msg += [str(d)]
raise ValueError('Dimension mismatch; shapes are %s' % raise ValueError('Dimension mismatch; shapes are %s' %
', '.join('(%s)' % ', '.join('*' if b else str(d) ', '.join('(%s)' % ', '.join(msg)))
for d, b in zip(input.shape, sinput.type.broadcastable)) #backport
for input, sinput in zip(inputs, node.inputs))) #raise ValueError('Dimension mismatch; shapes are %s' %
# ', '.join('(%s)' % ', '.join('*' if b else str(d)
# for d, b in zip(input.shape, sinput.type.broadcastable))
# for input, sinput in zip(inputs, node.inputs)))
# Other mismatches will be caught by the ufunc # Other mismatches will be caught by the ufunc
if not self.inplace_pattern: if not self.inplace_pattern:
for output, storage in zip(node.outputs, output_storage): for output, storage in zip(node.outputs, output_storage):
......
from .basic import _scal_elemwise #, _transpose_inplace from basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from theano import scalar as scal
import elemwise import elemwise
from .. import printing from theano import printing
from ..printing import pprint from theano.printing import pprint
from theano.gof.python25 import any
def _scal_inplace(symbol): def _scal_inplace(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op""" """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
......
...@@ -4,20 +4,20 @@ ...@@ -4,20 +4,20 @@
# TODO: 0*x -> 0 # TODO: 0*x -> 0
from .. import gof from theano import gof
from ..gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from theano import scalar
import basic as T import basic as T
import inplace as I import inplace as I
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
import sys import sys
from .. import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from ..compile.debugmode import _debugprint
from theano.compile.debugmode import _debugprint
from theano.gof.python25 import any
# Utilities # Utilities
...@@ -738,11 +738,20 @@ class Canonizer(gof.LocalOptimizer): ...@@ -738,11 +738,20 @@ class Canonizer(gof.LocalOptimizer):
def mul_calculate(num, denum, aslist=False, out_type=None): def mul_calculate(num, denum, aslist=False, out_type=None):
if not num and not denum: if not num and not denum:
# Smallest 1 possible. # Smallest 1 possible.
return [] if aslist else N.int8(1) if aslist:
return []
else:
return N.int8(1)
#return [] if aslist else N.int8(1)
# Make sure we do not accidently upcast data types. # Make sure we do not accidently upcast data types.
if out_type is None: if out_type is None:
# TODO: remove this error-causing heuristic # TODO: remove this error-causing heuristic
first = num[0] if num else denum[0] if num:
first = num[0]
else:
first = denum[0]
#first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1) one = N.asarray(first).dtype.type(1)
else: else:
one = N.asarray(1, dtype=out_type.dtype) one = N.asarray(1, dtype=out_type.dtype)
...@@ -850,15 +859,29 @@ def local_mul_specialize(node): ...@@ -850,15 +859,29 @@ def local_mul_specialize(node):
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if len(new_inputs) == 0:
newval = -y.flatten()[0] if neg else y.flatten()[0] if neg:
newval = -y.flatten()[0]
else:
newval = y.flatten()[0]
#newval = -y.flatten()[0] if neg else y.flatten()[0]
return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype, return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))) broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
return fill_chain(-new_inputs[0] if neg else new_inputs[0]) if neg:
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
# return fill_chain(-new_inputs[0] if neg else new_inputs[0])
else: else:
return fill_chain(-T.mul(*new_inputs) if neg else \ if neg:
T.mul(*new_inputs)) msg = -T.mul(*new_inputs)
else:
msg = T.mul(*new_inputs)
#return fill_chain(-T.mul(*new_inputs) if neg else \
# T.mul(*new_inputs))
else: else:
return False return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
...@@ -914,7 +937,11 @@ mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, lo ...@@ -914,7 +937,11 @@ mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, lo
def add_calculate(num, denum, aslist = False, out_type=None): def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar #TODO: make sure that this function and mul_calculate are similar
zero = 0.0 if out_type is None else N.asarray(0, dtype=out_type.dtype) if out_type is None:
zero = 0.0
else:
zero = N.asarray(0, dtype=out_type.dtype)
#zero = 0.0 if out_type is None else N.asarray(0, dtype=out_type.dtype)
v = reduce(N.add, num, zero) - reduce(N.add, denum, zero) v = reduce(N.add, num, zero) - reduce(N.add, denum, zero)
if aslist: if aslist:
if N.all(v == 0): if N.all(v == 0):
...@@ -1061,7 +1088,15 @@ def constant_folding(node): ...@@ -1061,7 +1088,15 @@ def constant_folding(node):
storage = [[None] for output in node.outputs] storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage) node.op.perform(node, [x.data for x in node.inputs], storage)
#TODO: think about how to extend to more types #TODO: think about how to extend to more types
return [(T.TensorConstant if isinstance(s[0], (N.ndarray,int,float)) else gof.Constant)(output.type, s[0]) for s, output in zip(storage, node.outputs)] msg = []
for s, output in zip(storage, node.outputs):
if isinstance(s[0], (N.ndarray,int,float)):
msg += [T.TensorConstant(output.type,s[0])]
else:
msg += [gof.Constant(output.type, s[0])]
return msg
#TODO: verify this backport!!
#return [(T.TensorConstant if isinstance(s[0], (N.ndarray,int,float)) else gof.Constant)(output.type, s[0]) for s, output in zip(storage, node.outputs)]
register_canonicalize(constant_folding) register_canonicalize(constant_folding)
register_specialize(constant_folding) register_specialize(constant_folding)
......
...@@ -4,9 +4,9 @@ __docformat__ = "restructuredtext en" ...@@ -4,9 +4,9 @@ __docformat__ = "restructuredtext en"
import sys import sys
import numpy import numpy
from ..compile import module, In, Component from theano.compile import module, In, Component
from ..gof import Container from theano.gof import Container
from ..tensor import raw_random from theano.tensor import raw_random
class RandomStreamsInstance(object): class RandomStreamsInstance(object):
"""RandomStreamsInstance""" """RandomStreamsInstance"""
...@@ -37,7 +37,10 @@ class RandomStreamsInstance(object): ...@@ -37,7 +37,10 @@ class RandomStreamsInstance(object):
:rtype: None :rtype: None
""" """
seed = self.default_seed if seed is None else seed if seed is None:
seed = self.default_seed
#backport
#seed = self.default_seed if seed is None else seed
seedgen = numpy.random.RandomState(seed) seedgen = numpy.random.RandomState(seed)
for old_r, new_r in self.random_streams.random_state_variables: for old_r, new_r in self.random_streams.random_state_variables:
old_r_seed = seedgen.randint(2**30) old_r_seed = seedgen.randint(2**30)
......
...@@ -7,8 +7,8 @@ import numpy ...@@ -7,8 +7,8 @@ import numpy
#local imports #local imports
import basic as tensor import basic as tensor
import opt import opt
from .. import gof from theano import gof
from ..compile import optdb from theano.compile import optdb
class RandomStateType(gof.Type): class RandomStateType(gof.Type):
"""A Type wrapper for numpy.RandomState """A Type wrapper for numpy.RandomState
...@@ -85,7 +85,12 @@ class RandomFunction(gof.Op): ...@@ -85,7 +85,12 @@ class RandomFunction(gof.Op):
def __setstate__(self, state): def __setstate__(self, state):
self.state = state self.state = state
fn, outtype, args, kwargs = state fn, outtype, args, kwargs = state
self.fn = getattr(numpy.random.RandomState, fn) if isinstance(fn, str) else fn if isinstance(fn, str):
self.fn = getattr(numpy.random.RandomState, fn)
else:
self.fn = fn
#backport
#self.fn = getattr(numpy.random.RandomState, fn) if isinstance(fn, str) else fn
self.outtype = outtype self.outtype = outtype
self.args = tuple(tensor.as_tensor_variable(arg) for arg in args) self.args = tuple(tensor.as_tensor_variable(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False) self.inplace = kwargs.pop('inplace', False)
...@@ -139,7 +144,12 @@ class RandomFunction(gof.Op): ...@@ -139,7 +144,12 @@ class RandomFunction(gof.Op):
inputs = [] inputs = []
for arg, default in zip(args, self.args): for arg, default in zip(args, self.args):
assert arg is None or default.type.dtype == arg.type.dtype assert arg is None or default.type.dtype == arg.type.dtype
input = default if arg is None else arg if arg is None:
input = default
else:
input = arg
#backport
#input = default if arg is None else arg
inputs.append(input) inputs.append(input)
return gof.Apply(self, return gof.Apply(self,
......
...@@ -10,7 +10,7 @@ from copy import copy ...@@ -10,7 +10,7 @@ from copy import copy
from theano import compile from theano import compile
from theano import gradient from theano import gradient
from theano import gof from theano import gof
from theano.gof.python25 import any from theano.gof.python25 import any, all
from theano import gof from theano import gof
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
......
import traceback import traceback
import theano.tensor as T import theano.tensor as T
from ...gof import Env from theano.gof import Env
from ...printing import pp from theano.printing import pp
import numpy import numpy
from theano.tensor.blas import * from theano.tensor.blas import *
from theano.tensor.blas import _dot22, res_is_a from theano.tensor.blas import _dot22, res_is_a
...@@ -13,7 +13,7 @@ _as_scalar = GemmLocalOptimizer._as_scalar ...@@ -13,7 +13,7 @@ _as_scalar = GemmLocalOptimizer._as_scalar
_is_real_matrix = GemmLocalOptimizer._is_real_matrix _is_real_matrix = GemmLocalOptimizer._is_real_matrix
from theano import In, Out from theano import In, Out
from .test_basic import (_approx_eq, as_tensor_variable, inplace_func, from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, value, constant, inplace, eval_outputs) compile, value, constant, inplace, eval_outputs)
class t_gemm(TestCase): class t_gemm(TestCase):
...@@ -415,4 +415,3 @@ def test_inplace1(): ...@@ -415,4 +415,3 @@ def test_inplace1():
# gemm should operate in-place on (Z+Z) # gemm should operate in-place on (Z+Z)
if (not gemm in [n.op for n in f.maker.env.nodes]): if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph') raise Failure('no gemm in graph')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论