提交 d325be92 authored 作者: Frederic Bastien's avatar Frederic Bastien

merge of changeset af0638d6d9bb. Back-port to python 2.4

...@@ -24,6 +24,7 @@ To learn more, check out: ...@@ -24,6 +24,7 @@ To learn more, check out:
""" """
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import gof
from gof import \ from gof import \
CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \ CLinker, OpWiseCLinker, DualLinker, Linker, LocalLinker, PerformLinker, \
......
from .. import gof from theano import gof
from .. import gradient as G from theano import gradient as G
from function_module import function from function_module import function
...@@ -56,7 +56,7 @@ class OpFromGraph(gof.Op): ...@@ -56,7 +56,7 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" % type, input.type) raise TypeError("Wrong type, expected %s but got %s" % (type, input.type))
return gof.Apply(self, return gof.Apply(self,
inputs, inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
......
...@@ -6,19 +6,19 @@ from StringIO import StringIO ...@@ -6,19 +6,19 @@ from StringIO import StringIO
import numpy import numpy
from .. import gof from theano import gof
from ..gof import Env, graph, utils, link from theano.gof import Env, graph, utils, link
from ..gof.link import WrapLinkerMany, raise_with_op from theano.gof.link import WrapLinkerMany, raise_with_op
from ..gof.cutils import run_cthunk #from theano.gof.cutils import run_cthunk
from ..gof.cc import OpWiseCLinker, CLinker from theano.gof.cc import OpWiseCLinker, CLinker
from ..compile.function_module import (FunctionMaker, from theano.compile.function_module import (FunctionMaker,
Function, Function,
infer_reuse_pattern, infer_reuse_pattern,
SymbolicInput, SymbolicInput,
SymbolicInputKit, SymbolicInputKit,
SymbolicOutput, SymbolicOutput,
Supervisor) Supervisor)
from ..compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
import logging import logging
_logger=logging.getLogger("theano.compile.debugmode") _logger=logging.getLogger("theano.compile.debugmode")
...@@ -260,7 +260,10 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout): ...@@ -260,7 +260,10 @@ def _debugprint(r, prefix='', depth=-1, done=None, file=sys.stdout):
""" """
if depth==0: if depth==0:
return return
done = set() if done is None else done #backport
if done is None:
done = set()
#done = set() if done is None else done
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
# this variable is the output of computation, # this variable is the output of computation,
# so just print out the apply # so just print out the apply
...@@ -548,9 +551,15 @@ def _find_bad_optimizations2(order, reasons, r_vals): ...@@ -548,9 +551,15 @@ def _find_bad_optimizations2(order, reasons, r_vals):
checked_variables.add(r) checked_variables.add(r)
# (recursively) first check all the variables that could make r look bad: # (recursively) first check all the variables that could make r look bad:
list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]]
if (None is not r.owner):
list_of_vars += r.owner.inputs
for var_that_could_make_r_look_bad in \ for var_that_could_make_r_look_bad in \
[old_r for (reason, old_r, olds, news) in reasons[r]] \ list_of_vars:
+ ([] if (None is r.owner) else r.owner.inputs): #backport
#[old_r for (reason, old_r, olds, news) in reasons[r]] \
#+ ([] if (None is r.owner) else r.owner.inputs):
check_variable(var_that_could_make_r_look_bad) check_variable(var_that_could_make_r_look_bad)
check_variable_norec(r) check_variable_norec(r)
...@@ -598,11 +607,18 @@ class _EnvEvent(object): ...@@ -598,11 +607,18 @@ class _EnvEvent(object):
def __str__(self): def __str__(self):
if self.kind == 'change': if self.kind == 'change':
if (self.op != 'output'):
msg = str(len(self.node.inputs))
else:
msg = ''
return ' '.join(['change', return ' '.join(['change',
self.reason, self.reason,
str(self.op), str(self.op),
str(self.idx), str(self.idx),
str(len(self.node.inputs)) if (self.op != 'output') else '']) msg])
#backport
#str(len(self.node.inputs)) if (self.op != 'output') else ''])
else: else:
return str(self.__dict__) return str(self.__dict__)
...@@ -1087,7 +1103,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1087,7 +1103,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
# WARNING: this is a global mechanism... so it will screw up if we are trying to use # WARNING: this is a global mechanism... so it will screw up if we are trying to use
# multiple modes at once. # multiple modes at once.
from ..tensor import TensorType #to set filter_check_isfinite from theano.tensor import TensorType #to set filter_check_isfinite
TensorType.filter_checks_isfinite = mode.check_isfinite TensorType.filter_checks_isfinite = mode.check_isfinite
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
...@@ -1125,8 +1141,19 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1125,8 +1141,19 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
for j in xrange(max(len(li), len(l0))): for j in xrange(max(len(li), len(l0))):
if j >= len(li) or j >= len(l0) or li[j] != l0[j]: if j >= len(li) or j >= len(l0) or li[j] != l0[j]:
print >> infolog, "* ", j, print >> infolog, "* ", j,
print >> infolog, " ", str(li[j]) if j < len(li) else '-', if j < len(li):
print >> infolog, " ", str(l0[j]) if j < len(l0) else '-' msg = str(li[j])
else:
msg = '-'
print >> infolog, " ", msg
if j < len(l0):
msg = str(l0[j])
else:
msg = '-'
print >> infolog, " ", msg
#backport
#print >> infolog, " ", str(li[j]) if j < len(li) else '-',
#print >> infolog, " ", str(l0[j]) if j < len(l0) else '-'
else: else:
pass pass
raise StochasticOrder(infolog.getvalue()) raise StochasticOrder(infolog.getvalue())
......
...@@ -6,14 +6,16 @@ __docformat__ = "restructuredtext en" ...@@ -6,14 +6,16 @@ __docformat__ = "restructuredtext en"
import copy_reg import copy_reg
import cPickle import cPickle
from functools import partial import sys
if sys.version_info[:2] >= (2,5):
from functools import partial
import numpy import numpy
from .. import gof import theano.gof
import sys #from theano import gof
import copy import copy
import time import time
import mode as mode_module import mode as mode_module
from io import * from io import *
...@@ -34,8 +36,18 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -34,8 +36,18 @@ def infer_reuse_pattern(env, outputs_to_disown):
do_not_reuse.append(r) do_not_reuse.append(r)
node = r.owner node = r.owner
op = node.op op = node.op
dmap = op.destroy_map if hasattr(op, 'destroy_map') else {} if hasattr(op, 'destroy_map'):
vmap = op.view_map if hasattr(op, 'view_map') else {} dmap = op.destroy_map
else:
dmap = {}
if hasattr(op, 'view_map'):
vmap = op.view_map
else:
vmap = {}
#backport
#dmap = op.destroy_map if hasattr(op, 'destroy_map') else {}
#vmap = op.view_map if hasattr(op, 'view_map') else {}
for l in dmap.values() + vmap.values(): for l in dmap.values() + vmap.values():
for i in l: for i in l:
walk(node.inputs[i]) walk(node.inputs[i])
...@@ -253,7 +265,12 @@ class Function(object): ...@@ -253,7 +265,12 @@ class Function(object):
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__)
finder[i] = c finder[i] = c
finder[input.variable] = c finder[input.variable] = c
finder[input.name] = c if input.name not in finder else DUPLICATE if input.name not in finder:
finder[input.name] = c
else:
finder[input.name] = DUPLICATE
#backport
#finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message) # inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input inv_finder[c] = input
#setters.append(partial(assign, c)) #setters.append(partial(assign, c))
...@@ -272,7 +289,12 @@ class Function(object): ...@@ -272,7 +289,12 @@ class Function(object):
# can reinitialize all the containers # can reinitialize all the containers
finder[i] = f finder[i] = f
finder[input] = f finder[input] = f
finder[input.name] = f if input.name not in finder else DUPLICATE if input.name not in finder:
finder[input.name] = f
else:
finder[input.name] = DUPLICATE
#backport
#finder[input.name] = f if input.name not in finder else DUPLICATE
#setters.append(f) #setters.append(f)
# For each input in the kit and its corresponding container, we put an entry in finder. # For each input in the kit and its corresponding container, we put an entry in finder.
# This allows the user to micro-manage elements of the kit if need be. # This allows the user to micro-manage elements of the kit if need be.
...@@ -280,7 +302,12 @@ class Function(object): ...@@ -280,7 +302,12 @@ class Function(object):
for c, sin in zip(cs, sinputs): for c, sin in zip(cs, sinputs):
finder[sin.variable] = c finder[sin.variable] = c
finder[sin.name] = c finder[sin.name] = c
finder[sin.name] = c if sin.name not in finder else DUPLICATE if sin.name not in finder:
finder[sin.name] = c
else:
finder[sin.name] = DUPLICATE
#backport
#finder[sin.name] = c if sin.name not in finder else DUPLICATE
inv_finder[c] = input inv_finder[c] = input
c.required = required c.required = required
c.provided = 0 c.provided = 0
...@@ -508,9 +535,15 @@ class SanityCheckFunction(Function): ...@@ -508,9 +535,15 @@ class SanityCheckFunction(Function):
if not input.mutable: if not input.mutable:
if not self.check_equal(c1.value, c2.value): if not self.check_equal(c1.value, c2.value):
name = c2.name name = c2.name
if name:
the_name = name
else:
the_name = ""
raise ValueError("Input #%i%s using %s and %s differs." raise ValueError("Input #%i%s using %s and %s differs."
% (i, % (i,
" (%s)" % name if name else "", #backport
#" (%s)" % name if name else "",
" (%s)" % the_name,
self.maker.mode, self.maker.mode,
fn.maker.mode), fn.maker.mode),
c1.value, c2.value) c1.value, c2.value)
...@@ -521,9 +554,15 @@ class SanityCheckFunction(Function): ...@@ -521,9 +554,15 @@ class SanityCheckFunction(Function):
r2 = c2.value r2 = c2.value
if not self.check_equal(r1, r2): if not self.check_equal(r1, r2):
name = c2.name name = c2.name
if name:
the_name = name
else:
the_name = ""
raise ValueError("Variable #%i%s using %s and %s differs." raise ValueError("Variable #%i%s using %s and %s differs."
% (i, % (i,
" (%s)" % name if name else "", #backport
#" (%s)" % name if name else "",
" (%s)" % the_name,
self.maker.mode, self.maker.mode,
fn.maker.mode), fn.maker.mode),
r1, r2) r1, r2)
...@@ -599,7 +638,10 @@ class FunctionMaker(object): ...@@ -599,7 +638,10 @@ class FunctionMaker(object):
in the graph from the inputs to the outputs in the graph from the inputs to the outputs
""" """
mode = mode if mode is not None else mode_module.default_mode if mode is None:
mode = mode_module.default_mode
#backport
#mode = mode if mode is not None else mode_module.default_mode
# Handle the case where inputs and/or outputs is a single Variable (not in a list) # Handle the case where inputs and/or outputs is a single Variable (not in a list)
unpack_single = False unpack_single = False
...@@ -704,7 +746,16 @@ class FunctionMaker(object): ...@@ -704,7 +746,16 @@ class FunctionMaker(object):
def _pickle_FunctionMaker(fm): def _pickle_FunctionMaker(fm):
outputs = None if fm.return_none else (fm.outputs[0] if fm.unpack_single else fm.outputs) if fm.return_none:
outputs = None
else:
if fm.unpack_single:
outputs = fm.outputs[0]
else:
outputs = fm.outputs
#backport
#outputs = None if fm.return_none else (fm.outputs[0] if fm.unpack_single else fm.outputs)
rval = (_constructor_FunctionMaker, (fm.inputs, outputs, fm.mode, fm.accept_inplace)) rval = (_constructor_FunctionMaker, (fm.inputs, outputs, fm.mode, fm.accept_inplace))
return rval return rval
...@@ -790,12 +841,19 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -790,12 +841,19 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
""" """
t1 = time.time() t1 = time.time()
mode = mode if mode is not None else mode_module.default_mode if mode is None:
mode = mode_module.default_mode
#backport
#mode = mode if mode is not None else mode_module.default_mode
inputs = map(convert_function_input, inputs) inputs = map(convert_function_input, inputs)
if outputs is not None: if outputs is not None:
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs) if isinstance(outputs, (list, tuple)):
outputs = map(FunctionMaker.wrap_out, outputs)
else:
outputs = FunctionMaker.wrap_out(outputs)
#backport
#outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs] defaults = [getattr(input, 'value', None) for input in inputs]
...@@ -809,9 +867,17 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -809,9 +867,17 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
#return a different kind of function #return a different kind of function
def dup_defaults(): def dup_defaults():
# TODO This may need to be changed to use containers as defaults. # TODO This may need to be changed to use containers as defaults.
return [copy.copy(default.value) if isinstance(default, gof.Container) else retval = []
copy.copy(default) for default in defaults:
for default in defaults] if isinstance(default, gof.Container):
retval +=[copy.copy(default.value)]
else:
retval +=[copy.copy(default)]
return retval
#backport
#return [copy.copy(default.value) if isinstance(default, gof.Container) else
# copy.copy(default)
# for default in defaults]
makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]] makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]]
fns = [maker.create(dup_defaults(), trustme = True) for maker in makers] fns = [maker.create(dup_defaults(), trustme = True) for maker in makers]
builder = partial(SanityCheckFunction, fns, check_equal) builder = partial(SanityCheckFunction, fns, check_equal)
...@@ -820,10 +886,11 @@ def function(inputs, outputs, mode=None, accept_inplace = False): ...@@ -820,10 +886,11 @@ def function(inputs, outputs, mode=None, accept_inplace = False):
else: else:
Maker = getattr(mode, 'function_maker', FunctionMaker) Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults) fn = Maker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
t2 = time.time() t2 = time.time()
if hasattr(mode, 'compile_time'): if hasattr(mode, 'compile_time'):
mode.compile_time+=t2-t1 mode.compile_time+=t2-t1
return fn return fn
......
...@@ -39,11 +39,23 @@ class SymbolicInput(object): ...@@ -39,11 +39,23 @@ class SymbolicInput(object):
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
self.name = variable.name if (autoname and name is None) else name if (autoname and name is None):
self.name = variable.name
else:
self.name = name
#backport
#self.name = variable.name if (autoname and name is None) else name
if self.name is not None and not isinstance(self.name, str): if self.name is not None and not isinstance(self.name, str):
raise TypeError("name must be a string! (got: %s)" % self.name) raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update self.update = update
self.mutable = mutable if (mutable is not None) else (update is not None) if (mutable is not None):
self.mutable = mutable
else:
self.mutable = (update is not None)
#backport
#self.mutable = mutable if (mutable is not None) else (update is not None)
self.strict = strict self.strict = strict
self.implicit = implicit self.implicit = implicit
......
...@@ -8,10 +8,14 @@ __docformat__ = "restructuredtext en" ...@@ -8,10 +8,14 @@ __docformat__ = "restructuredtext en"
from theano import gof from theano import gof
from theano.printing import pprint from theano.printing import pprint
from collections import defaultdict
from itertools import chain
from functools import partial
import io, sys import io, sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
from itertools import chain
if sys.version_info[:2] >= (2,5):
from functools import partial
import function_module as F import function_module as F
import mode as get_mode import mode as get_mode
...@@ -418,8 +422,26 @@ class Method(Component): ...@@ -418,8 +422,26 @@ class Method(Component):
outputs = self.outputs outputs = self.outputs
_inputs = [x.variable for x in inputs] _inputs = [x.variable for x in inputs]
# Grab the variables that are not accessible from either the inputs or the updates. # Grab the variables that are not accessible from either the inputs or the updates.
outputs_list = [] if outputs is None else (list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]) if outputs is None:
outputs_variable_list = [o.variable if isinstance(o, io.Out) else o for o in outputs_list] outputs_list = []
else:
if isinstance(outputs, (list, tuple)):
outputs_list = list(outputs)
else:
outputs_list = [outputs]
#backport
#outputs_list = [] if outputs is None else (list(outputs) if isinstance(outputs, (list, tuple)) else [outputs])
outputs_variable_list = []
for o in outputs_list:
if isinstance(o, io.Out):
outputs_variable_list += [o.variable]
else:
outputs_variable_list += [o]
#backport
#outputs_variable_list = [o.variable if isinstance(o, io.Out) else o for o in outputs_list]
for input in gof.graph.inputs(outputs_variable_list for input in gof.graph.inputs(outputs_variable_list
+ [x.update for x in inputs if getattr(x, 'update', False)], + [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs): blockers = _inputs):
...@@ -448,7 +470,13 @@ class Method(Component): ...@@ -448,7 +470,13 @@ class Method(Component):
assert type(storage) is io.In assert type(storage) is io.In
inputs.append(storage) inputs.append(storage)
effective_mode = mode if self.mode is None else self.mode if self.mode is None:
effective_mode = mode
else:
effective_mode = self.mode
#backport
#effective_mode = mode if self.mode is None else self.mode
rval = F.function(inputs, outputs, effective_mode) rval = F.function(inputs, outputs, effective_mode)
memo[self] = rval memo[self] = rval
return rval return rval
...@@ -459,7 +487,13 @@ class Method(Component): ...@@ -459,7 +487,13 @@ class Method(Component):
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs)) rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else: else:
rval = '' rval = ''
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates if isinstance(self.outputs, (list, tuple)):
inputs, outputs, updates = self.inputs, self.outputs
else:
inputs, outputs, updates = [self.outputs], self.updates
#backport
#inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
# If mode is in kwargs, prints the optimized version of the method # If mode is in kwargs, prints the optimized version of the method
mode = kwargs.pop('mode', None) mode = kwargs.pop('mode', None)
...@@ -472,10 +506,16 @@ class Method(Component): ...@@ -472,10 +506,16 @@ class Method(Component):
return rval return rval
def __str__(self): def __str__(self):
if self.updates:
sep = "; "
else:
sep = ""
return "Method(%s -> %s%s%s)" % \ return "Method(%s -> %s%s%s)" % \
(self.inputs, (self.inputs,
self.outputs, self.outputs,
"; " if self.updates else "", sep,
#backport
#"; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems())) ", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
...@@ -603,7 +643,16 @@ class Composite(Component): ...@@ -603,7 +643,16 @@ class Composite(Component):
self.set(item, value) self.set(item, value)
def __iter__(self): def __iter__(self):
return (c.r if isinstance(c, (External, Member)) else c for c in self.components()) retval = []
for c in self.components():
if isinstance(c, (External, Member)):
retval += [c.r]
else:
retval += [c]
return retval
#backport
#return (c.r if isinstance(c, (External, Member)) else c for c in self.components())
...@@ -1047,7 +1096,12 @@ class Module(ComponentDict): ...@@ -1047,7 +1096,12 @@ class Module(ComponentDict):
# Function to go through member lists and dictionaries recursively, # Function to go through member lists and dictionaries recursively,
# to look for submodules on which make_module_instance needs to be called # to look for submodules on which make_module_instance needs to be called
def recurse(v): def recurse(v):
iter = enumerate(v) if isinstance(v,list) else v.iteritems() if isinstance(v,list):
iter = enumerate(v)
else:
iter = v.iteritems()
#backport
#iter = enumerate(v) if isinstance(v,list) else v.iteritems()
for sk,sv in iter: for sk,sv in iter:
if isinstance(sv,(list,dict)): if isinstance(sv,(list,dict)):
sv = recurse(sv) sv = recurse(sv)
......
import time, atexit import time, atexit
from ..gof.link import WrapLinkerMany from theano.gof.link import WrapLinkerMany
from ..gof.cutils import run_cthunk from theano.gof.cutils import run_cthunk
from ..compile.mode import Mode, predefined_linkers, register_mode, predefined_modes from theano.compile.mode import Mode, predefined_linkers, register_mode, predefined_modes
from ..gof.cc import OpWiseCLinker from theano.gof.cc import OpWiseCLinker
class ProfileMode(Mode): class ProfileMode(Mode):
def __init__(self, linker=OpWiseCLinker(), optimizer=None): def __init__(self, linker=OpWiseCLinker(), optimizer=None):
...@@ -93,20 +93,24 @@ class ProfileMode(Mode): ...@@ -93,20 +93,24 @@ class ProfileMode(Mode):
break break
print '\nOp-wise summary: < of local_time spent on this kind of Op> <cumulative seconds> <self seconds>%s <Op name>'%(flops_msg) print '\nOp-wise summary: < of local_time spent on this kind of Op> <cumulative seconds> <self seconds>%s <Op name>'%(flops_msg)
otimes = [(t/local_time, t, a, self.op_cimpl[a]) for a, t in op_time.items()] otimes = [(t/local_time, t, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
tot=0 tot=0
for f,t,a,ci in otimes[:n_ops_to_print]: for f,t,a,ci in otimes[:n_ops_to_print]:
tot+=t tot+=t
if ci:
msg = '*'
else:
msg = ' '
m=-1 m=-1
if hasattr(a,'flops'): if hasattr(a,'flops'):
m=a.flops*self.op_call[a]/t/1e6 m=a.flops*self.op_call[a]/t/1e6
if flops: if flops:
print ' %4.1f%% %.3fs %.3fs %s %7.1f %s' % (f*100, tot, t, '*' if ci else ' ', m,a) print ' %4.1f%% %.3fs %.3fs %s %7.1f %s' % (f*100, tot, t, msg, m,a)
else: else:
print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a) print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a)
print ' ... (remaining %i Ops account for %6.2f%%(%.2fs) of the runtime)'\ print ' ... (remaining %i Ops account for %6.2f%%(%.2fs) of the runtime)'\
%(max(0, len(otimes)-n_ops_to_print), %(max(0, len(otimes)-n_ops_to_print),
sum(f for f, t, a, ci in otimes[n_ops_to_print:])*100, sum(f for f, t, a, ci in otimes[n_ops_to_print:])*100,
...@@ -128,7 +132,11 @@ class ProfileMode(Mode): ...@@ -128,7 +132,11 @@ class ProfileMode(Mode):
tot=0 tot=0
for f,t,a,ci in sotimes[:n_ops_to_print]: for f,t,a,ci in sotimes[:n_ops_to_print]:
tot+=t tot+=t
print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, '*' if ci else ' ', a) if ci:
msg = '*'
else:
msg = ' '
print ' %4.1f%% %.3fs %.3fs %s %s' % (f*100, tot, t, msg, a)
print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\ print ' ... (remaining %i Ops account for %.2f%%(%.2fs) of the runtime)'\
%(max(0, len(sotimes)-n_ops_to_print), %(max(0, len(sotimes)-n_ops_to_print),
sum(f for f, t, a in sotimes[n_ops_to_print:])*100, sum(f for f, t, a in sotimes[n_ops_to_print:])*100,
...@@ -150,3 +158,4 @@ def atexit_print_default_profile_mode(): ...@@ -150,3 +158,4 @@ def atexit_print_default_profile_mode():
#Register atexit_print_default_profile_mode to have the summary of the #Register atexit_print_default_profile_mode to have the summary of the
#predefined mode PROFILE_MODE if it is used printed when the program terminate. #predefined mode PROFILE_MODE if it is used printed when the program terminate.
atexit.register(atexit_print_default_profile_mode) atexit.register(atexit_print_default_profile_mode)
...@@ -29,13 +29,21 @@ class CallCache(object): ...@@ -29,13 +29,21 @@ class CallCache(object):
self.cache = {} self.cache = {}
def persist(self, filename=None): def persist(self, filename=None):
filename = self.filename if filename is None else filename if filename is None:
filename = self.filename
#backport
#filename = self.filename if filename is None else filename
f = file(filename, 'w') f = file(filename, 'w')
cPickle.dump(self.cache, f) cPickle.dump(self.cache, f)
f.close() f.close()
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
key = (fn, tuple(args)) if key is None else key if key is None:
key = (fn, tuple(args))
#backport
#key = (fn, tuple(args)) if key is None else key
if key not in self.cache: if key not in self.cache:
debug('cache miss', len(self.cache)) debug('cache miss', len(self.cache))
self.cache[key] = fn(*args) self.cache[key] = fn(*args)
......
...@@ -5,7 +5,10 @@ Defines Linkers that deal with C implementations. ...@@ -5,7 +5,10 @@ Defines Linkers that deal with C implementations.
# Python imports # Python imports
from copy import copy from copy import copy
import re #for set_compiledir import re #for set_compiledir
import os, sys, platform, StringIO, time, hashlib import os, sys, platform, StringIO, time
import md5
if sys.version_info[:2] >= (2,5):
import hashlib
# weave import # weave import
from scipy import weave from scipy import weave
...@@ -37,7 +40,7 @@ def error(*args): ...@@ -37,7 +40,7 @@ def error(*args):
sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n') sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
_logger.error(' '.join(str(a) for a in args)) _logger.error(' '.join(str(a) for a in args))
from .callcache import CallCache from theano.gof.callcache import CallCache
def get_module_cache(): def get_module_cache():
return cmodule.get_module_cache(get_compiledir()) return cmodule.get_module_cache(get_compiledir())
...@@ -520,7 +523,12 @@ class CLinker(link.Linker): ...@@ -520,7 +523,12 @@ class CLinker(link.Linker):
# The hash calculated on the code identifies it so weave can cache properly. # The hash calculated on the code identifies it so weave can cache properly.
# (the hash has to be used outside of the support code because weave does not consider changes in the support code) # (the hash has to be used outside of the support code because weave does not consider changes in the support code)
hash = hashlib.md5(struct_code).hexdigest() # hashlib is new to 2.5
if sys.version_info[:2] < (2,5):
hash = md5.new(struct_code).hexdigest()
else:
hash = hashlib.md5(struct_code).hexdigest()
struct_name = '__struct_compiled_op_%s' % hash struct_name = '__struct_compiled_op_%s' % hash
#struct_code %= dict(name = struct_name) #struct_code %= dict(name = struct_name)
struct_code = re.sub("<<<<NAME>>>>", struct_name, struct_code) struct_code = re.sub("<<<<NAME>>>>", struct_name, struct_code)
...@@ -819,7 +827,10 @@ class CLinker(link.Linker): ...@@ -819,7 +827,10 @@ class CLinker(link.Linker):
""" """
This method is a callback for `ModuleCache.module_from_key` This method is a callback for `ModuleCache.module_from_key`
""" """
location = get_compiledir() if location is None else location if location is None:
location = get_compiledir()
#backport
#location = get_compiledir() if location is None else location
mod = self.build_dynamic_module() mod = self.build_dynamic_module()
get_lock() get_lock()
try: try:
......
...@@ -242,7 +242,8 @@ class ModuleCache(object): ...@@ -242,7 +242,8 @@ class ModuleCache(object):
self.module_from_name = dict(self.module_from_name) self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key) self.entry_from_key = dict(self.entry_from_key)
self.stats = [0, 0, 0] self.stats = [0, 0, 0]
self.force_fresh = self.force_fresh if force_fresh is None else force_fresh if force_fresh is not None:
self.force_fresh = force_fresh
self.loaded_key_pkl = set() self.loaded_key_pkl = set()
self.refresh() self.refresh()
...@@ -398,7 +399,9 @@ class ModuleCache(object): ...@@ -398,7 +399,9 @@ class ModuleCache(object):
:param age_thresh: dynamic modules whose last access time is more than ``age_thresh`` :param age_thresh: dynamic modules whose last access time is more than ``age_thresh``
seconds ago will be erased. seconds ago will be erased.
""" """
age_thresh = self.age_thresh if age_thresh is None else age_thresh if age_thresh is None:
age_thresh = self.age_thresh
compilelock.get_lock() compilelock.get_lock()
try: try:
# update the age of modules that have been accessed by other processes # update the age of modules that have been accessed by other processes
...@@ -469,6 +472,12 @@ def get_lib_extension(): ...@@ -469,6 +472,12 @@ def get_lib_extension():
else: else:
return 'so' return 'so'
def get_gcc_shared_library_arg():
"""Return the platform-dependent GCC argument for shared libraries."""
if sys.platform == 'darwin':
return '-dynamiclib'
else:
return '-shared'
def std_include_dirs(): def std_include_dirs():
return [distutils.sysconfig.get_python_inc()] + numpy.distutils.misc_util.get_numpy_include_dirs() return [distutils.sysconfig.get_python_inc()] + numpy.distutils.misc_util.get_numpy_include_dirs()
...@@ -494,7 +503,6 @@ def std_libs(): ...@@ -494,7 +503,6 @@ def std_libs():
def std_lib_dirs(): def std_lib_dirs():
return std_lib_dirs_and_libs()[1] return std_lib_dirs_and_libs()[1]
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[]): preargs=[]):
""" """
...@@ -509,13 +517,34 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -509,13 +517,34 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
:returns: dynamically-imported python module of the compiled code. :returns: dynamically-imported python module of the compiled code.
""" """
#TODO: don't to the dlimport in this function #TODO: don't to the dlimport in this function
preargs= [] if preargs is None else list(preargs)
if preargs is None:
preargs = []
else:
preargs = list(preargs)
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
include_dirs = std_include_dirs() + include_dirs include_dirs = std_include_dirs() + include_dirs
libs = std_libs() + libs libs = std_libs() + libs
lib_dirs = std_lib_dirs() + lib_dirs lib_dirs = std_lib_dirs() + lib_dirs
if sys.platform == 'win32':
python_inc = distutils.sysconfig.get_python_inc()
# Typical include directory: C:\Python26\include
libname = os.path.basename(os.path.dirname(python_inc)).lower()
# Also add directory containing the Python library to the library
# directories.
python_lib_dir = os.path.join(os.path.dirname(python_inc), 'libs')
lib_dirs = [python_lib_dir] + lib_dirs
else:
# Typical include directory: /usr/include/python2.6
python_inc = distutils.sysconfig.get_python_inc()
libname = os.path.basename(python_inc)
libs = [libname] + libs
workdir = location
cppfilename = os.path.join(location, 'mod.cpp') cppfilename = os.path.join(location, 'mod.cpp')
cppfile = file(cppfilename, 'w') cppfile = file(cppfilename, 'w')
...@@ -531,7 +560,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -531,7 +560,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
(module_name, get_lib_extension())) (module_name, get_lib_extension()))
debug('Generating shared lib', lib_filename) debug('Generating shared lib', lib_filename)
cmd = ['g++', '-shared', '-g'] cmd = ['g++', get_gcc_shared_library_arg(), '-g']
if no_opt: if no_opt:
cmd.extend(p for p in preargs if not p.startswith('-O')) cmd.extend(p for p in preargs if not p.startswith('-O'))
else: else:
...@@ -556,10 +585,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -556,10 +585,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
#touch the __init__ file #touch the __init__ file
file(os.path.join(location, "__init__.py"),'w').close() file(os.path.join(location, "__init__.py"),'w').close()
return dlimport(lib_filename) return dlimport(lib_filename)
def icc_module_compile_str(*args): def icc_module_compile_str(*args):
raise NotImplementedError() raise NotImplementedError()
"""WRITEME""" """WRITEME"""
from collections import defaultdict import sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
# otherwise it's implemented in python25.py
import toolbox import toolbox
import graph import graph
......
...@@ -138,7 +138,11 @@ class Container(object): ...@@ -138,7 +138,11 @@ class Container(object):
self.type = r self.type = r
else: else:
self.type = r.type self.type = r.type
self.name = r.name if name is None else name if name is None:
self.name = r.name
#backport
#self.name = r.name if name is None else name
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict self.strict = strict
......
...@@ -4,6 +4,7 @@ amount of useful generic optimization tools. ...@@ -4,6 +4,7 @@ amount of useful generic optimization tools.
""" """
import sys
import graph import graph
from env import InconsistencyError from env import InconsistencyError
import utils import utils
...@@ -11,9 +12,13 @@ import unify ...@@ -11,9 +12,13 @@ import unify
import toolbox import toolbox
import op import op
from copy import copy from copy import copy
from collections import deque, defaultdict from theano.gof.python25 import any, all
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
from collections import deque
import destroyhandler as dh import destroyhandler as dh
import sys
import traceback import traceback
_optimizer_idx = [0] _optimizer_idx = [0]
...@@ -901,7 +906,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -901,7 +906,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = True max_use_abort = True
else: else:
lopt_change = self.process_node(env, node, lopt) lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0 if lopt_change:
process_count[lopt] += 1
else:
process_count[lopt] += 0
#backport
#process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change changed |= lopt_change
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
......
import sys, StringIO import sys, StringIO
from collections import defaultdict
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
else:
from python25 import defaultdict
import opt import opt
......
...@@ -23,6 +23,42 @@ if sys.version_info[:2] < (2,5): ...@@ -23,6 +23,42 @@ if sys.version_info[:2] < (2,5):
newfunc.args = args newfunc.args = args
newfunc.keywords = keywords newfunc.keywords = keywords
return newfunc return newfunc
class defaultdict(dict):
def __init__(self, default_factory=None, *a, **kw):
if (default_factory is not None and
not hasattr(default_factory, '__call__')):
raise TypeError('first argument must be callable')
dict.__init__(self, *a, **kw)
self.default_factory = default_factory
def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self.__missing__(key)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = value = self.default_factory()
return value
def __reduce__(self):
if self.default_factory is None:
args = tuple()
else:
args = self.default_factory,
# consider replacing items() with iteritems()
return type(self), args, None, None, self.items()
def copy(self):
return self.__copy__()
def __copy__(self):
return type(self)(self.default_factory, self)
def __deepcopy__(self, memo):
import copy
return type(self)(self.default_factory,
copy.deepcopy(self.items()))
def __repr__(self):
return 'defaultdict(%s, %s)' % (self.default_factory,
dict.__repr__(self))
else: else:
# Only bother with this else clause and the __all__ line if you are putting # Only bother with this else clause and the __all__ line if you are putting
# this in a separate file. # this in a separate file.
......
import sys
if sys.version_info[:2] >= (2,5):
from functools import partial
from functools import partial
import graph import graph
import sys
class AlreadyThere(Exception): class AlreadyThere(Exception):
......
...@@ -32,7 +32,13 @@ class Print(Op): ...@@ -32,7 +32,13 @@ class Print(Op):
xout[0] = xin xout[0] = xin
for attr in self.attrs: for attr in self.attrs:
temp = getattr(xin, attr) temp = getattr(xin, attr)
print self.message, attr,'=', temp() if callable(temp) else temp if callable(temp):
pmsg = temp()
else:
psmg = temp
print self.message, attr,'=', pmsg
#backport
#print self.message, attr,'=', temp() if callable(temp) else temp
def grad(self,input,output_gradients): def grad(self,input,output_gradients):
return output_gradients return output_gradients
...@@ -233,7 +239,12 @@ class PPrinter: ...@@ -233,7 +239,12 @@ class PPrinter:
strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output)))) strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output))))
i += 1 i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
name = 'out[%i]' % outputs.index(output) if output.name is None else output.name if output.name is None:
name = 'out[%i]' % outputs.index(output)
else:
name = output.name
#backport
#name = 'out[%i]' % outputs.index(output) if output.name is None else output.name
current = output current = output
try: try:
idx = 2000 + outputs.index(output) idx = 2000 + outputs.index(output)
......
...@@ -99,7 +99,7 @@ class ConvOp(Op): ...@@ -99,7 +99,7 @@ class ConvOp(Op):
new-=1 new-=1
print "OPTIMISATION WARNING: in ConvOp.__init__() unroll_batch(%s) must be 0 or a divisor of bsize(%s). We revert it to %d. This won't change the result, but may make it slower."%(str(self.unroll_batch),str(self.bsize),new) print "OPTIMISATION WARNING: in ConvOp.__init__() unroll_batch(%s) must be 0 or a divisor of bsize(%s). We revert it to %d. This won't change the result, but may make it slower."%(str(self.unroll_batch),str(self.bsize),new)
self.unroll_batch=mew self.unroll_batch=new
if self.unroll_kern>0 and self.nkern % unroll_kern!=0: if self.unroll_kern>0 and self.nkern % unroll_kern!=0:
if self.nkern<=self.unroll_kern: if self.nkern<=self.unroll_kern:
self.unroll_kern = self.nkern self.unroll_kern = self.nkern
......
...@@ -4,9 +4,9 @@ from copy import copy ...@@ -4,9 +4,9 @@ from copy import copy
import numpy import numpy
from .. import gof from theano import gof
from ..gof import Op, utils, Variable, Constant, Type, Apply, Env from theano.gof import Op, utils, Variable, Constant, Type, Apply, Env
from ..gof.python25 import partial from theano.gof.python25 import partial, all
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
...@@ -278,7 +278,14 @@ class transfer_type(gof.utils.object2): ...@@ -278,7 +278,14 @@ class transfer_type(gof.utils.object2):
self.transfer = transfer self.transfer = transfer
def __call__(self, *types): def __call__(self, *types):
upcast = upcast_out(*types) upcast = upcast_out(*types)
return [upcast if i is None else types[i] for i in self.transfer] retval = []
for i in self.transfer:
if i is None:
retval += [upcast]
else:
retval += [types[i]]
return retval
#return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.transfer == other.transfer return type(self) == type(other) and self.transfer == other.transfer
def __hash__(self): def __hash__(self):
...@@ -465,8 +472,21 @@ class InRange(LogicalComparison): ...@@ -465,8 +472,21 @@ class InRange(LogicalComparison):
return False return False
return True return True
def c_code(self, node, name, (x, low, hi), (z, ), sub): def c_code(self, node, name, (x, low, hi), (z, ), sub):
cmp1 = '>' if self.openlow else '>=' if self.openlow:
cmp2 = '<' if self.openhi else '<=' cmp1 = '>'
else:
cmp1 = '>='
#backport
#cmp1 = '>' if self.openlow else '>='
if self.openhi:
cmp2 = '<'
else:
cmp2 = '<='
#backport
#cmp2 = '<' if self.openhi else '<='
return "%(z)s = %(x)s %(cmp1)s %(low)s && %(x)s %(cmp2)s %(hi)s;" % locals() return "%(z)s = %(x)s %(cmp1)s %(low)s && %(x)s %(cmp2)s %(hi)s;" % locals()
def grad(self, (x, low, hi), (gz, )): def grad(self, (x, low, hi), (gz, )):
return None, None, None return None, None, None
...@@ -476,13 +496,32 @@ inclosedrange = InRange(False, False) ...@@ -476,13 +496,32 @@ inclosedrange = InRange(False, False)
class Switch(ScalarOp): class Switch(ScalarOp):
nin = 3 nin = 3
def impl(self, cond, ift, iff): def impl(self, cond, ift, iff):
return ift if cond else iff if cond:
return ift
else:
return iff
#backport
#return ift if cond else iff
def c_code(self, node, name, (cond, ift, iff), (z, ), sub): def c_code(self, node, name, (cond, ift, iff), (z, ), sub):
return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals() return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()
def grad(self, (cond, ift, iff), (gz, )): def grad(self, (cond, ift, iff), (gz, )):
return (None, if ift.type in grad_types:
switch(cond, gz, 0) if ift.type in grad_types else None, first_part = switch(cond, gz, 0)
switch(cond, 0, gz) if iff.type in grad_types else None) else:
first_part = None
if iff.type in grad_types:
second_part = switch(cond, 0, gz)
else:
second_part = None
return (None, first_part, second_part)
#return (None,
# switch(cond, gz, 0) if ift.type in grad_types else None,
# switch(cond, 0, gz) if iff.type in grad_types else None)
def output_types(self, (cond_t, ift_t, iff_t)): def output_types(self, (cond_t, ift_t, iff_t)):
return upcast_out(ift_t, iff_t) return upcast_out(ift_t, iff_t)
switch = Switch() switch = Switch()
...@@ -558,7 +597,15 @@ class Add(ScalarOp): ...@@ -558,7 +597,15 @@ class Add(ScalarOp):
else: else:
return z + " = " + " + ".join(inputs) + ";" return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return [(gz if i.type in grad_types else None) for i in inputs] retval = []
for i in inputs:
if i.type in grad_types:
retval += [gz]
else:
retval += [None]
return retval
#backport
#return [(gz if i.type in grad_types else None) for i in inputs]
add = Add(upcast_out, name = 'add') add = Add(upcast_out, name = 'add')
class Mul(ScalarOp): class Mul(ScalarOp):
...@@ -573,9 +620,18 @@ class Mul(ScalarOp): ...@@ -573,9 +620,18 @@ class Mul(ScalarOp):
else: else:
return z + " = " + " * ".join(inputs) + ";" return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
return [(mul(*([gz] + utils.difference(inputs, [input]))) retval = []
if input.type in grad_types else None) for input in inputs:
for input in inputs] if input.type in grad_types:
retval += [mul(*([gz] + utils.difference(inputs, [input])))]
else:
retval += [None]
return retval
#return [(mul(*([gz] + utils.difference(inputs, [input])))
# if input.type in grad_types else None)
# for input in inputs]
mul = Mul(upcast_out, name = 'mul') mul = Mul(upcast_out, name = 'mul')
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
...@@ -584,7 +640,19 @@ class Sub(BinaryScalarOp): ...@@ -584,7 +640,19 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz if x.type in grad_types else None, -gz if y.type in grad_types else None if x.type in grad_types:
first_part = gz
else:
first_part = None
if y.type in grad_types:
second_part = -gz
else:
second_part = None
return first_part, second_part
#return gz if x.type in grad_types else None, -gz if y.type in grad_types else None
sub = Sub(upcast_out, name = 'sub') sub = Sub(upcast_out, name = 'sub')
def div_proxy(x, y): def div_proxy(x, y):
...@@ -613,8 +681,20 @@ class TrueDiv(BinaryScalarOp): ...@@ -613,8 +681,20 @@ class TrueDiv(BinaryScalarOp):
return "%(z)s = ((double)%(x)s) / %(y)s;" % locals() return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return (gz / y if x.type in grad_types else None, if x.type in grad_types:
-(gz * x) / (y * y) if y.type in grad_types else None) first_part = gz / y
else:
first_part = None
if y.type in grad_types:
second_part = -(gz * x) / (y * y)
else:
second_part = None
return (first_part, second_part)
#return (gz / y if x.type in grad_types else None,
# -(gz * x) / (y * y) if y.type in grad_types else None)
true_div = TrueDiv(upcast_out, name = 'true_div') true_div = TrueDiv(upcast_out, name = 'true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
...@@ -642,19 +722,43 @@ class Pow(BinaryScalarOp): ...@@ -642,19 +722,43 @@ class Pow(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return (gz * y * x**(y - 1) if x.type in grad_types else None, if x.type in grad_types:
gz * log(x) * x**y if y.type in grad_types else None) first_part = gz * y * x**(y - 1)
else:
first_part = None
if y.type in grad_types:
second_part = gz * log(x) * x**y
else:
second_part = None
return (first_part, second_part)
#return (gz * y * x**(y - 1) if x.type in grad_types else None,
# gz * log(x) * x**y if y.type in grad_types else None)
pow = Pow(upcast_out, name = 'pow') pow = Pow(upcast_out, name = 'pow')
class Clip(ScalarOp): class Clip(ScalarOp):
nin = 3 nin = 3
def impl(self, x, min, max): def impl(self, x, min, max):
return min if x < min else max if x > max else x if x < min:
return min
elif x > max:
return max
else:
return x
#return min if x < min else max if x > max else x
def c_code(self, node, name, (x, min, max), (z, ), sub): def c_code(self, node, name, (x, min, max), (z, ), sub):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals() return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )): def grad(self, (x, min, max), (gz, )):
gx = ((x > min) & (x < max)) * gz gx = ((x > min) & (x < max)) * gz
return gx if x.type in grad_types else None, None, None if x.type in grad_types:
return gx
else:
return None,None,None
#return gx if x.type in grad_types else None, None, None
clip = Clip(transfer_type(0), name = 'clip') clip = Clip(transfer_type(0), name = 'clip')
class First(BinaryScalarOp): class First(BinaryScalarOp):
...@@ -663,7 +767,12 @@ class First(BinaryScalarOp): ...@@ -663,7 +767,12 @@ class First(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz if x.type in grad_types else None, None if x.type in grad_types:
return gz
else:
return None,None
#backport
#return gz if x.type in grad_types else None, None
first = First(transfer_type(0), name = 'first') first = First(transfer_type(0), name = 'first')
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
...@@ -672,7 +781,13 @@ class Second(BinaryScalarOp): ...@@ -672,7 +781,13 @@ class Second(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals() return "%(z)s = %(y)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return None, gz if y.type in grad_types else None if y.type in grad_types:
return None, gz
else:
return None
#backport
#return None, gz if y.type in grad_types else None
second = Second(transfer_type(1), name = 'second') second = Second(transfer_type(1), name = 'second')
...@@ -683,7 +798,13 @@ class Identity(UnaryScalarOp): ...@@ -683,7 +798,13 @@ class Identity(UnaryScalarOp):
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz if x.type in grad_types else None, if x.type in grad_types:
return gz,
else:
return None,
#backport
#return gz if x.type in grad_types else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
...@@ -699,7 +820,12 @@ class Abs(UnaryScalarOp): ...@@ -699,7 +820,12 @@ class Abs(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.abs(x) return numpy.abs(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sgn(x) if x.type in grad_types else None, if x.type in grad_types:
return gz * sgn(x),
else:
return None,
#backport
#return gz * sgn(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
type = node.inputs[0].type type = node.inputs[0].type
if type in int_types: if type in int_types:
...@@ -735,7 +861,12 @@ class Neg(UnaryScalarOp): ...@@ -735,7 +861,12 @@ class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz if x.type in grad_types else None, if x.type in grad_types:
return -gz,
else:
return None,
#backport
#return -gz if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg') neg = Neg(same_out, name = 'neg')
...@@ -744,7 +875,13 @@ class Inv(UnaryScalarOp): ...@@ -744,7 +875,13 @@ class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1.0 / x return 1.0 / x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz / (x * x) if x.type in grad_types else None, if x.type in grad_types:
return -gz / (x * x),
else:
return None,
#backport
#return -gz / (x * x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = 1.0 / %(x)s;" % locals() return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name = 'inv') inv = Inv(upgrade_to_float, name = 'inv')
...@@ -753,7 +890,12 @@ class Log(UnaryScalarOp): ...@@ -753,7 +890,12 @@ class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / x if x.type in grad_types else None, if x.type in grad_types:
return gz / x,
else:
return None,
#backport
#return gz / x if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster #todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching # on some machines for some reason, check if it's worth switching
...@@ -765,7 +907,13 @@ class Log2(UnaryScalarOp): ...@@ -765,7 +907,13 @@ class Log2(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (x * math.log(2.0)) if x.type in grad_types else None, if x.type in grad_types:
return gz / (x * math.log(2.0)),
else:
return None,
#backport
#return gz / (x * math.log(2.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log2(%(x)s);" % locals() return "%(z)s = log2(%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name = 'log2') log2 = Log2(upgrade_to_float, name = 'log2')
...@@ -774,7 +922,13 @@ class Log10(UnaryScalarOp): ...@@ -774,7 +922,13 @@ class Log10(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.log10(x) return numpy.log10(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / (x * math.log(10.0)) if x.type in grad_types else None, if x.type in grad_types:
return gz / (x * math.log(10.0)),
else:
return None
#backport
#return gz / (x * math.log(10.0)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = log10(%(x)s);" % locals() return "%(z)s = log10(%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name = 'log10') log10 = Log10(upgrade_to_float, name = 'log10')
...@@ -783,7 +937,13 @@ class Exp(UnaryScalarOp): ...@@ -783,7 +937,13 @@ class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * exp(x) if x.type in grad_types else None, if x.type in grad_types:
return gz * exp(x),
else:
return None,
#backport
#return gz * exp(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals() return "%(z)s = exp(%(x)s);" % locals()
exp = Exp(upgrade_to_float, name = 'exp') exp = Exp(upgrade_to_float, name = 'exp')
...@@ -792,7 +952,13 @@ class Sqr(UnaryScalarOp): ...@@ -792,7 +952,13 @@ class Sqr(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x*x return x*x
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * x * 2 if x.type in grad_types else None, if x.type in grad_types:
return gz * x * 2,
else:
return None,
#backport
# return gz * x * 2 if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = %(x)s * %(x)s;" % locals() return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name = 'sqr') sqr = Sqr(same_out, name = 'sqr')
...@@ -801,7 +967,12 @@ class Sqrt(UnaryScalarOp): ...@@ -801,7 +967,12 @@ class Sqrt(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sqrt(x) return math.sqrt(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return (gz * 0.5) / sqrt(x) if x.type in grad_types else None, if x.type in grad_types:
return (gz * 0.5) / sqrt(x),
else:
return None,
#backport
#return (gz * 0.5) / sqrt(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sqrt(%(x)s);" % locals() return "%(z)s = sqrt(%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name = 'sqrt') sqrt = Sqrt(upgrade_to_float, name = 'sqrt')
...@@ -810,7 +981,12 @@ class Cos(UnaryScalarOp): ...@@ -810,7 +981,12 @@ class Cos(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cos(x) return math.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return -gz * sin(x) if x.type in grad_types else None, if x.type in grad_types:
return -gz * sin(x),
else:
return None,
#backport
# return -gz * sin(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cos(%(x)s);" % locals() return "%(z)s = cos(%(x)s);" % locals()
cos = Cos(upgrade_to_float, name = 'cos') cos = Cos(upgrade_to_float, name = 'cos')
...@@ -819,7 +995,12 @@ class Sin(UnaryScalarOp): ...@@ -819,7 +995,12 @@ class Sin(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sin(x) return math.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cos(x) if x.type in grad_types else None, if x.type in grad_types:
return gz * cos(x),
else:
return None,
#backport
# return gz * cos(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sin(%(x)s);" % locals() return "%(z)s = sin(%(x)s);" % locals()
sin = Sin(upgrade_to_float, name = 'sin') sin = Sin(upgrade_to_float, name = 'sin')
...@@ -828,7 +1009,12 @@ class Tan(UnaryScalarOp): ...@@ -828,7 +1009,12 @@ class Tan(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tan(x) return math.tan(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / sqr(cos(x)) if x.type in grad_types else None, if x.type in grad_types:
return gz / sqr(cos(x)),
else:
return None,
#backport
#return gz / sqr(cos(x)) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tan(%(x)s);" % locals() return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan') tan = Tan(upgrade_to_float, name = 'tan')
...@@ -840,7 +1026,12 @@ class Cosh(UnaryScalarOp): ...@@ -840,7 +1026,12 @@ class Cosh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.cosh(x) return math.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * sinh(x) if x.type in grad_types else None, if x.type in grad_types:
return gz * sinh(x),
else:
return None,
#backport
#return gz * sinh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = cosh(%(x)s);" % locals() return "%(z)s = cosh(%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name = 'cosh') cosh = Cosh(upgrade_to_float, name = 'cosh')
...@@ -852,7 +1043,12 @@ class Sinh(UnaryScalarOp): ...@@ -852,7 +1043,12 @@ class Sinh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.sinh(x) return math.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * cosh(x) if x.type in grad_types else None, if x.type in grad_types:
return gz * cosh(x),
else:
return None,
#backport
#return gz * cosh(x) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = sinh(%(x)s);" % locals() return "%(z)s = sinh(%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name = 'sinh') sinh = Sinh(upgrade_to_float, name = 'sinh')
...@@ -865,7 +1061,12 @@ class Tanh(UnaryScalarOp): ...@@ -865,7 +1061,12 @@ class Tanh(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.tanh(x) return math.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None, if x.type in grad_types:
return gz * (1 - sqr(tanh(x))),
else:
return None,
#backport
#return gz * (1 - sqr(tanh(x))) if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = tanh(%(x)s);" % locals() return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh') tanh = Tanh(upgrade_to_float, name = 'tanh')
......
...@@ -12,10 +12,10 @@ from scipy import sparse ...@@ -12,10 +12,10 @@ from scipy import sparse
import scipy.sparse import scipy.sparse
from theano.printing import Print from theano.printing import Print
from .. import gof from theano import gof
from .. import tensor from theano import tensor
from .. import compile from theano import compile
from .. import scalar from theano import scalar
#TODO: move this decorator to the compile submodule #TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
...@@ -273,7 +273,12 @@ class CSMProperties(gof.Op): ...@@ -273,7 +273,12 @@ class CSMProperties(gof.Op):
[data, tensor.ivector(), tensor.ivector(), tensor.ivector()]) [data, tensor.ivector(), tensor.ivector(), tensor.ivector()])
def perform(self, node, (csm,), out): def perform(self, node, (csm,), out):
out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap] if self.kmap is None:
out[0][0] = csm.data
else:
out[0][0] = csm.data[self.kmap]
#backport
#out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap]
out[1][0] = numpy.asarray(csm.indices, dtype='int32') out[1][0] = numpy.asarray(csm.indices, dtype='int32')
out[2][0] = numpy.asarray(csm.indptr, dtype='int32') out[2][0] = numpy.asarray(csm.indptr, dtype='int32')
out[3][0] = numpy.asarray(csm.shape, dtype='int32') out[3][0] = numpy.asarray(csm.shape, dtype='int32')
...@@ -1082,8 +1087,19 @@ register_specialize(local_structured_dot) ...@@ -1082,8 +1087,19 @@ register_specialize(local_structured_dot)
def structured_dot_grad(sparse_A, dense_B, ga): def structured_dot_grad(sparse_A, dense_B, ga):
if sparse_A.type.format in ('csc','csr'): if sparse_A.type.format in ('csc','csr'):
sdgcsx = sdg_csc if sparse_A.type.format == 'csc' else sdg_csr if sparse_A.type.format == 'csc':
CSx = CSC if sparse_A.type.format == 'csc' else CSR sdgcsx = sdg_csc
else:
sdgcsx = sdg_csr
#backport
#sdgcsx = sdg_csc if sparse_A.type.format == 'csc' else sdg_csr
if sparse_A.type.format == 'csc':
CSx = CSC
else:
CSx = CSR
#backport
#CSx = CSC if sparse_A.type.format == 'csc' else CSR
g_A_data = sdgcsx(csm_indices(sparse_A),\ g_A_data = sdgcsx(csm_indices(sparse_A),\
csm_indptr(sparse_A), dense_B, ga) csm_indptr(sparse_A), dense_B, ga)
......
...@@ -5,22 +5,23 @@ __docformat__ = "restructuredtext en" ...@@ -5,22 +5,23 @@ __docformat__ = "restructuredtext en"
import __builtin__ import __builtin__
import sys # for sys.maxint import sys # for sys.maxint
import traceback #for overriding Op.__call__ import traceback #for overriding Op.__call__
import functools if sys.version_info >= (2,5):
import functools
import numpy import numpy
from copy import copy from copy import copy
from .. import gof from theano import gof
from ..gof import Variable, Op, utils, Type, Constant, Apply, Value from theano.gof import Variable, Op, utils, Type, Constant, Apply, Value
from .. import gradient from theano import gradient
import elemwise import elemwise
from .. import scalar as scal from theano import scalar as scal
from ..gof.python25 import partial from theano.gof.python25 import partial, any
from .. import compile, printing from theano import compile, printing
from ..printing import pprint, Print from theano.printing import pprint, Print
### set up the external interface ### set up the external interface
from elemwise import Elemwise, DimShuffle, CAReduce, Sum from elemwise import Elemwise, DimShuffle, CAReduce, Sum
...@@ -352,12 +353,18 @@ class TensorType(Type): ...@@ -352,12 +353,18 @@ class TensorType(Type):
return self.name return self.name
else: else:
b = self.broadcastable b = self.broadcastable
#bcast = str(self.broadcastable) named_broadcastable = {(): 'scalar',
bcast = {(): 'scalar',
(False,): 'vector', (False,): 'vector',
(False, True): 'col', (False, True): 'col',
(True, False): 'row', (True, False): 'row',
(False, False): 'matrix'}.get(b, "%iD" % len(b) if not any(b) else str(b)) (False, False): 'matrix'}
if b in named_broadcastable:
bcast = named_broadcastable[b]
else:
if any(b):
bcast = str(b)
else:
bcast = '%iD' % len(b)
return "TensorType(%s, %s)" % (str(self.dtype), bcast) return "TensorType(%s, %s)" % (str(self.dtype), bcast)
def __repr__(self): def __repr__(self):
...@@ -833,7 +840,11 @@ def _scal_elemwise(symbol): ...@@ -833,7 +840,11 @@ def _scal_elemwise(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op""" """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__ symbolname = symbol.__name__
inplace = symbolname.endswith('_inplace') inplace = symbolname.endswith('_inplace')
n="Elemwise{%s,%s}"%(symbolname,"inplace" if inplace else "no_inplace") if inplace:
msg = "inplace"
else:
msg = "no_inplace"
n="Elemwise{%s,%s}"%(symbolname,msg)
if inplace: if inplace:
scalar_op = getattr(scal, symbolname[:-len('_inplace')]) scalar_op = getattr(scal, symbolname[:-len('_inplace')])
...@@ -987,11 +998,15 @@ class MaxAndArgmax(Op): ...@@ -987,11 +998,15 @@ class MaxAndArgmax(Op):
if not ( axis.data == 0 or axis.data == x.ndim-1): if not ( axis.data == 0 or axis.data == x.ndim-1):
raise NotImplementedError('MaxAndArgmax gradient with axis corresponding to internal dimension') raise NotImplementedError('MaxAndArgmax gradient with axis corresponding to internal dimension')
g_max_pad = shape_padleft(g_max) if axis.data==0 else \ if axis.data==0:
shape_padright(g_max) g_max_pad = shape_padleft(g_max)
else:
g_max_pad = shape_padright(g_max)
xmax = max(x, axis) xmax = max(x, axis)
xmax_pad = shape_padleft(xmax) if axis.data==0 else \ if axis.data==0:
shape_padright(xmax) xmax_pad = shape_padleft(xmax)
else:
xmax_pad = shape_padright(xmax)
g_x = eq(xmax_pad, x) * g_max_pad g_x = eq(xmax_pad, x) * g_max_pad
return g_x, None return g_x, None
...@@ -1261,10 +1276,10 @@ class Filler(gof.Op): ...@@ -1261,10 +1276,10 @@ class Filler(gof.Op):
def __hash__(self): def __hash__(self):
return hash(self.ndim) ^ hash(self.dtype) return hash(self.ndim) ^ hash(self.dtype)
Zeros = functools.partial(Filler, 0) Zeros = partial(Filler, 0)
"""WRITEME""" """WRITEME"""
Ones = functools.partial(Filler, 1) Ones = partial(Filler, 1)
"""WRITEME""" """WRITEME"""
@constructor @constructor
...@@ -1368,8 +1383,18 @@ class Repeat(gof.Op): ...@@ -1368,8 +1383,18 @@ class Repeat(gof.Op):
assert isinstance(input.type, TensorType) assert isinstance(input.type, TensorType)
assert repeats.type == iscalar assert repeats.type == iscalar
assert axis.type == iscalar assert axis.type == iscalar
type = TensorType(dtype = input.type.dtype, broadcastable = []
broadcastable = [False if i==axis else x for i, x in enumerate(input.broadcastable)]) for i,x in enumerate(input.broadcastable):
if i==axis:
broadcastable += [False]
else:
broadcastable += [x]
type = TensorType(dtype = input.type.dtype, broadcastable = \
broadcastable)
#backport
#type = TensorType(dtype = input.type.dtype,
# broadcastable = [False if i==axis else x for i, x in enumerate(input.broadcastable)])
return gof.Apply(self, [inputs, repeats, axis], [type()]) return gof.Apply(self, [inputs, repeats, axis], [type()])
def perform(self, node, (input, repeats, axis), (out, )): def perform(self, node, (input, repeats, axis), (out, )):
...@@ -1480,27 +1505,46 @@ class Subtensor(Op): ...@@ -1480,27 +1505,46 @@ class Subtensor(Op):
@staticmethod @staticmethod
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar] tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Variable) and entry.type in scal_types: if isinstance(entry, gof.Variable) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, gof.Type) and entry in scal_types:
return entry return entry
if isinstance(entry, gof.Variable) and entry.type in tensor_types: if isinstance(entry, gof.Variable) and entry.type in tensor_types:
return scal.Scalar(entry.type.dtype) return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types: elif isinstance(entry, gof.Type) and entry in tensor_types:
return scal.Scalar(entry.dtype) return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
b = entry.stop b = entry.stop
c = entry.step c = entry.step
return slice(Subtensor.convert(a, False) if a is not None else None,
Subtensor.convert(b, False) if b is not None else None, if a is not None:
Subtensor.convert(c, False) if c is not None else None) slice_a = Subtensor.convert(a, False)
elif isinstance(entry, int):
return entry
else: else:
raise TypeError(Subtensor.e_indextype, entry) slice_a = None
if b is not None:
slice_b = Subtensor.convert(b, False)
else:
slice_b = None
if c is not None:
slice_c = Subtensor.convert(c, False)
else:
slice_c = None
return slice(slice_a,slice_b,slice_c)
#backport
#return slice(Subtensor.convert(a, False) if a is not None else None,
# Subtensor.convert(b, False) if b is not None else None,
# Subtensor.convert(c, False) if c is not None else None)
elif isinstance(entry, int):
return entry
else:
raise TypeError(Subtensor.e_indextype, entry)
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list) self.idx_list = map(self.convert, idx_list)
...@@ -1564,17 +1608,34 @@ class Subtensor(Op): ...@@ -1564,17 +1608,34 @@ class Subtensor(Op):
def __hash__(self): def __hash__(self):
#TODO: optimize by cache this hash value #TODO: optimize by cache this hash value
idx_list = tuple((entry.start, entry.stop, entry.step) msg = []
if isinstance(entry, slice) for entry in self.idx_list:
else entry if isinstance(entry, slice):
for entry in self.idx_list) msg += [(entry.start, entry.stop, entry.step)]
else:
msg += [entry]
idx_list = tuple(msg)
#backport
#idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice)
# else entry
# for entry in self.idx_list)
return hash(idx_list) return hash(idx_list)
def __str__(self): def __str__(self):
indices = [] indices = []
for entry in self.idx_list: for entry in self.idx_list:
if isinstance(entry, slice): if isinstance(entry, slice):
indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step])) msg = []
for x in [entry.start, entry.stop, entry.step]:
if x is None:
msg += ""
else:
msg += [str(x)]
indices.append(":".join(msg))
#backport
#indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step]))
else: else:
indices.append(str(entry)) indices.append(str(entry))
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices)) return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
...@@ -1598,11 +1659,27 @@ class SubtensorPrinter: ...@@ -1598,11 +1659,27 @@ class SubtensorPrinter:
elif isinstance(entry, scal.Scalar): elif isinstance(entry, scal.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop())) sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice): elif isinstance(entry, slice):
sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start, if entry.start is None or entry.start==0:
"" if entry.stop is None or entry.stop == sys.maxint else entry.stop, msg1 = ""
"" if entry.step is None else ":%s" % entry.step)) else:
return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)), msg1 = entry.start
", ".join(sidxs))
if entry.stop is None or entry.stop == sys.maxint:
msg2 = ""
else:
msg2 = entry.stop
if entry.step is None:
msg3 = ""
else:
msg3 = ":%s" % entry.step
sidxs.append("%s:%s%s" % (msg1, msg2, msg3))
#backport
#sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
# "" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
# "" if entry.step is None else ":%s" % entry.step))
return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)), ", ".join(sidxs))
else: else:
raise TypeError("Can only print Subtensor.") raise TypeError("Can only print Subtensor.")
...@@ -1631,20 +1708,44 @@ class SetSubtensor(Op): ...@@ -1631,20 +1708,44 @@ class SetSubtensor(Op):
and self.inplace == other.inplace and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
idx_list = tuple((entry.start, entry.stop, entry.step) msg = []
if isinstance(entry, slice) for entry in self.idx_list:
else entry if isinstance(entry, slice):
for entry in self.idx_list) msg += [(entry.start, entry.stop, entry.step)]
else:
msg += [entry]
idx_list = tuple(msg)
#backport
#idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice)
# else entry
# for entry in self.idx_list)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace)
def __str__(self): def __str__(self):
indices = [] indices = []
for entry in self.idx_list: for entry in self.idx_list:
if isinstance(entry, slice): if isinstance(entry, slice):
indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step])) msg = []
for x in [entry.start, entry.stop, entry.step]:
if x is None:
msg += ""
else:
msg += [str(x)]
indices.append(":".join(msg))
#backport
#indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step]))
else: else:
indices.append(str(entry)) indices.append(str(entry))
return "%s%s{%s}" % ('Inplace' if self.inplace else '', if self.inplace:
msg = 'Inplace'
else:
msg = ''
#backport
#return "%s%s{%s}" % ('Inplace' if self.inplace else '',
return "%s%s{%s}" % (msg,
self.__class__.__name__, ", ".join(indices)) self.__class__.__name__, ", ".join(indices))
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
...@@ -2119,12 +2220,10 @@ class Reshape(Op): ...@@ -2119,12 +2220,10 @@ class Reshape(Op):
def __eq__(self, other): def __eq__(self, other):
# .name does not participate because it doesn't affect computations # .name does not participate because it doesn't affect computations
return (type(other) == type(self)) and (other.ndim == self.ndim) return (type(other) is Reshape) and (other.ndim == self.ndim)
def __hash__(self): def __hash__(self):
# .name does not participate because it doesn't affect computations # .name does not participate because it doesn't affect computations
return hash(type(self)) ^ hash(self.ndim) return hash(Reshape) ^ hash(self.ndim)
def __str__(self):
return '%s{%i}' % (self.__class__.__name__, self.ndim)
def make_node(self, x, shp): def make_node(self, x, shp):
x = as_tensor_variable(x) x = as_tensor_variable(x)
shp = as_tensor_variable(shp) shp = as_tensor_variable(shp)
...@@ -2218,7 +2317,11 @@ class Tile(Op): ...@@ -2218,7 +2317,11 @@ class Tile(Op):
def tile(x, reps, ndim=None): def tile(x, reps, ndim=None):
if not hasattr(tile, 'op'): if not hasattr(tile, 'op'):
tile.op = {} tile.op = {}
ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work. if ndim is None:
ndim = len(reps)
#backport
#ndim = len(reps) if ndim is None else ndim #not sure if len(shp) is going to work.
if ndim not in tile.op: if ndim not in tile.op:
tile.op[ndim] = Tile(ndim) tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps) return tile.op[ndim](x, reps)
......
...@@ -3,19 +3,19 @@ ...@@ -3,19 +3,19 @@
import os, sys, traceback import os, sys, traceback
import numpy import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox) InconsistencyError, toolbox)
from ..printing import pprint, FunctionPrinter from theano.printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer from theano.tensor.opt import register_specialize, out2in, insert_inplace_optimizer
# opt.py # opt.py
import basic as T import basic as T
#NB: this clobbers the builtin 'compile' symbol #NB: this clobbers the builtin 'compile' symbol
from .. import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text from theano.tensor.blas_headers import cblas_header_text, blas_header_text
@utils.memoize @utils.memoize
def ldflags(libs=True, flags=False): def ldflags(libs=True, flags=False):
...@@ -363,11 +363,16 @@ class Gemm(GemmRelated): ...@@ -363,11 +363,16 @@ class Gemm(GemmRelated):
gemm = Gemm() gemm = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm')) pprint.assign(gemm, FunctionPrinter('gemm'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
return node.owner \ if maxclients is not None:
retval = (len(node.clients) <= maxclients)
else:
retval = True
return node.owner \
and node.owner.op == op \ and node.owner.op == op \
and (len(node.clients) <= maxclients if maxclients is not None else True) and retval
def _as_scalar(res): def _as_scalar(res):
"""Return None or a TensorVariable whose type is in T.float_scalar_types""" """Return None or a TensorVariable whose type is in T.float_scalar_types"""
...@@ -504,7 +509,6 @@ def _gemm_from_node(node): ...@@ -504,7 +509,6 @@ def _gemm_from_node(node):
assert len(sM_list) == len(sM_orig) assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other_inputs) == len(node.inputs) assert len(sM_list) + len(other_inputs) == len(node.inputs)
if len(sM_list) == 2: if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list (sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR)
......
...@@ -2,13 +2,13 @@ import sys ...@@ -2,13 +2,13 @@ import sys
import elemwise_cgen as cgen import elemwise_cgen as cgen
import numpy import numpy
from .. import gof from theano import gof
from ..gof import Op, Apply from theano.gof import Op, Apply
from .. import scalar from theano import scalar
from ..scalar import Scalar from theano.scalar import Scalar
from .. import printing from theano import printing
from ..printing import pprint from theano.printing import pprint
from ..gof.python25 import all from theano.gof.python25 import all
from copy import copy, deepcopy from copy import copy, deepcopy
...@@ -216,19 +216,31 @@ class DimShuffle(Op): ...@@ -216,19 +216,31 @@ class DimShuffle(Op):
'0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')] '0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]
shape_statements = ['npy_intp dimensions[%i]'%nd_out] shape_statements = ['npy_intp dimensions[%i]'%nd_out]
shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']') for i, o in enumerate(self.new_order):
if o != 'x' else if o != 'x':
('dimensions['+str(i)+'] = 1') shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']')]
for i, o in enumerate(self.new_order)] else:
shape_statements += [('dimensions['+str(i)+'] = 1')]
#backport
#shape_statements += [('dimensions['+str(i)+'] = %(basename)s->dimensions['+str(o)+']')
# if o != 'x' else
# ('dimensions['+str(i)+'] = 1')
# for i, o in enumerate(self.new_order)]
strides_statements = ['npy_intp strides[%i]'%nd_out] strides_statements = ['npy_intp strides[%i]'%nd_out]
#set the strides of the non-broadcasted dimensions #set the strides of the non-broadcasted dimensions
strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']') for i, o in enumerate(self.new_order):
if o != 'x' else if o != 'x':
('strides['+str(i)+'] = 0') strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']')]
for i, o in enumerate(self.new_order)] else:
strides_statements += [('strides['+str(i)+'] = 0')]
#backport
#strides_statements += [('strides['+str(i)+'] = %(basename)s->strides['+str(o)+']')
# if o != 'x' else
# ('strides['+str(i)+'] = 0')
# for i, o in enumerate(self.new_order)]
# set the strides of the broadcasted dimensions # set the strides of the broadcasted dimensions
# this algorithm is from numpy: PyArray_Newshape() in cvs/numpy/numpy/core/src/multiarraymodule.c # this algorithm is from numpy: PyArray_Newshape() in cvs/numpy/numpy/core/src/multiarraymodule.c
...@@ -443,7 +455,16 @@ class Elemwise(Op): ...@@ -443,7 +455,16 @@ class Elemwise(Op):
def _rehash(self): def _rehash(self):
items = self.inplace_pattern.items() items = self.inplace_pattern.items()
items.sort() items.sort()
tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items]) first_part = [k for k,v in items]
second_part = []
for k,v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
#backport
#tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items])
h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items) h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items)
assert h == getattr(self,'_hashval', h) assert h == getattr(self,'_hashval', h)
self._hashval = h self._hashval = h
...@@ -518,10 +539,21 @@ class Elemwise(Op): ...@@ -518,10 +539,21 @@ class Elemwise(Op):
for dims in zip(*[[(1, True)]*(maxsize - len(input.shape)) + zip(input.shape, sinput.type.broadcastable) for dims in zip(*[[(1, True)]*(maxsize - len(input.shape)) + zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)]): for input, sinput in zip(inputs, node.inputs)]):
if max(d for d,b in dims) != 1 and (1, False) in dims: if max(d for d,b in dims) != 1 and (1, False) in dims:
msg = []
for input, sinput in zip(inputs, node.inputs):
for d, b in zip(input.shape, sinput.type.broadcastable):
if b:
msg += ['*']
else:
msg += [str(d)]
raise ValueError('Dimension mismatch; shapes are %s' % raise ValueError('Dimension mismatch; shapes are %s' %
', '.join('(%s)' % ', '.join('*' if b else str(d) ', '.join('(%s)' % ', '.join(msg)))
for d, b in zip(input.shape, sinput.type.broadcastable)) #backport
for input, sinput in zip(inputs, node.inputs))) #raise ValueError('Dimension mismatch; shapes are %s' %
# ', '.join('(%s)' % ', '.join('*' if b else str(d)
# for d, b in zip(input.shape, sinput.type.broadcastable))
# for input, sinput in zip(inputs, node.inputs)))
# Other mismatches will be caught by the ufunc # Other mismatches will be caught by the ufunc
if not self.inplace_pattern: if not self.inplace_pattern:
for output, storage in zip(node.outputs, output_storage): for output, storage in zip(node.outputs, output_storage):
......
from .basic import _scal_elemwise #, _transpose_inplace from basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from theano import scalar as scal
import elemwise import elemwise
from .. import printing from theano import printing
from ..printing import pprint from theano.printing import pprint
from theano.gof.python25 import any
def _scal_inplace(symbol): def _scal_inplace(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op""" """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
......
...@@ -4,20 +4,20 @@ ...@@ -4,20 +4,20 @@
# TODO: 0*x -> 0 # TODO: 0*x -> 0
from .. import gof from theano import gof
from ..gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from theano import scalar
import basic as T import basic as T
import inplace as I import inplace as I
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
import sys import sys
from .. import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from ..compile.debugmode import _debugprint
from theano.compile.debugmode import _debugprint
from theano.gof.python25 import any
# Utilities # Utilities
...@@ -738,11 +738,20 @@ class Canonizer(gof.LocalOptimizer): ...@@ -738,11 +738,20 @@ class Canonizer(gof.LocalOptimizer):
def mul_calculate(num, denum, aslist=False, out_type=None): def mul_calculate(num, denum, aslist=False, out_type=None):
if not num and not denum: if not num and not denum:
# Smallest 1 possible. # Smallest 1 possible.
return [] if aslist else N.int8(1) if aslist:
return []
else:
return N.int8(1)
#return [] if aslist else N.int8(1)
# Make sure we do not accidently upcast data types. # Make sure we do not accidently upcast data types.
if out_type is None: if out_type is None:
# TODO: remove this error-causing heuristic # TODO: remove this error-causing heuristic
first = num[0] if num else denum[0] if num:
first = num[0]
else:
first = denum[0]
#first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1) one = N.asarray(first).dtype.type(1)
else: else:
one = N.asarray(1, dtype=out_type.dtype) one = N.asarray(1, dtype=out_type.dtype)
...@@ -850,15 +859,29 @@ def local_mul_specialize(node): ...@@ -850,15 +859,29 @@ def local_mul_specialize(node):
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if len(new_inputs) == 0:
newval = -y.flatten()[0] if neg else y.flatten()[0] if neg:
newval = -y.flatten()[0]
else:
newval = y.flatten()[0]
#newval = -y.flatten()[0] if neg else y.flatten()[0]
return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype, return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))) broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
return fill_chain(-new_inputs[0] if neg else new_inputs[0]) if neg:
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
# return fill_chain(-new_inputs[0] if neg else new_inputs[0])
else: else:
return fill_chain(-T.mul(*new_inputs) if neg else \ if neg:
T.mul(*new_inputs)) msg = -T.mul(*new_inputs)
else:
msg = T.mul(*new_inputs)
#return fill_chain(-T.mul(*new_inputs) if neg else \
# T.mul(*new_inputs))
else: else:
return False return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
...@@ -914,7 +937,11 @@ mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, lo ...@@ -914,7 +937,11 @@ mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, lo
def add_calculate(num, denum, aslist = False, out_type=None): def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar #TODO: make sure that this function and mul_calculate are similar
zero = 0.0 if out_type is None else N.asarray(0, dtype=out_type.dtype) if out_type is None:
zero = 0.0
else:
zero = N.asarray(0, dtype=out_type.dtype)
#zero = 0.0 if out_type is None else N.asarray(0, dtype=out_type.dtype)
v = reduce(N.add, num, zero) - reduce(N.add, denum, zero) v = reduce(N.add, num, zero) - reduce(N.add, denum, zero)
if aslist: if aslist:
if N.all(v == 0): if N.all(v == 0):
...@@ -1061,7 +1088,15 @@ def constant_folding(node): ...@@ -1061,7 +1088,15 @@ def constant_folding(node):
storage = [[None] for output in node.outputs] storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage) node.op.perform(node, [x.data for x in node.inputs], storage)
#TODO: think about how to extend to more types #TODO: think about how to extend to more types
return [(T.TensorConstant if isinstance(s[0], (N.ndarray,int,float)) else gof.Constant)(output.type, s[0]) for s, output in zip(storage, node.outputs)] msg = []
for s, output in zip(storage, node.outputs):
if isinstance(s[0], (N.ndarray,int,float)):
msg += [T.TensorConstant(output.type,s[0])]
else:
msg += [gof.Constant(output.type, s[0])]
return msg
#TODO: verify this backport!!
#return [(T.TensorConstant if isinstance(s[0], (N.ndarray,int,float)) else gof.Constant)(output.type, s[0]) for s, output in zip(storage, node.outputs)]
register_canonicalize(constant_folding) register_canonicalize(constant_folding)
register_specialize(constant_folding) register_specialize(constant_folding)
......
...@@ -4,9 +4,9 @@ __docformat__ = "restructuredtext en" ...@@ -4,9 +4,9 @@ __docformat__ = "restructuredtext en"
import sys import sys
import numpy import numpy
from ..compile import module, In, Component from theano.compile import module, In, Component
from ..gof import Container from theano.gof import Container
from ..tensor import raw_random from theano.tensor import raw_random
class RandomStreamsInstance(object): class RandomStreamsInstance(object):
"""RandomStreamsInstance""" """RandomStreamsInstance"""
...@@ -37,7 +37,10 @@ class RandomStreamsInstance(object): ...@@ -37,7 +37,10 @@ class RandomStreamsInstance(object):
:rtype: None :rtype: None
""" """
seed = self.default_seed if seed is None else seed if seed is None:
seed = self.default_seed
#backport
#seed = self.default_seed if seed is None else seed
seedgen = numpy.random.RandomState(seed) seedgen = numpy.random.RandomState(seed)
for old_r, new_r in self.random_streams.random_state_variables: for old_r, new_r in self.random_streams.random_state_variables:
old_r_seed = seedgen.randint(2**30) old_r_seed = seedgen.randint(2**30)
......
...@@ -7,8 +7,8 @@ import numpy ...@@ -7,8 +7,8 @@ import numpy
#local imports #local imports
import basic as tensor import basic as tensor
import opt import opt
from .. import gof from theano import gof
from ..compile import optdb from theano.compile import optdb
class RandomStateType(gof.Type): class RandomStateType(gof.Type):
"""A Type wrapper for numpy.RandomState """A Type wrapper for numpy.RandomState
...@@ -85,7 +85,12 @@ class RandomFunction(gof.Op): ...@@ -85,7 +85,12 @@ class RandomFunction(gof.Op):
def __setstate__(self, state): def __setstate__(self, state):
self.state = state self.state = state
fn, outtype, args, kwargs = state fn, outtype, args, kwargs = state
self.fn = getattr(numpy.random.RandomState, fn) if isinstance(fn, str) else fn if isinstance(fn, str):
self.fn = getattr(numpy.random.RandomState, fn)
else:
self.fn = fn
#backport
#self.fn = getattr(numpy.random.RandomState, fn) if isinstance(fn, str) else fn
self.outtype = outtype self.outtype = outtype
self.args = tuple(tensor.as_tensor_variable(arg) for arg in args) self.args = tuple(tensor.as_tensor_variable(arg) for arg in args)
self.inplace = kwargs.pop('inplace', False) self.inplace = kwargs.pop('inplace', False)
...@@ -139,7 +144,12 @@ class RandomFunction(gof.Op): ...@@ -139,7 +144,12 @@ class RandomFunction(gof.Op):
inputs = [] inputs = []
for arg, default in zip(args, self.args): for arg, default in zip(args, self.args):
assert arg is None or default.type.dtype == arg.type.dtype assert arg is None or default.type.dtype == arg.type.dtype
input = default if arg is None else arg if arg is None:
input = default
else:
input = arg
#backport
#input = default if arg is None else arg
inputs.append(input) inputs.append(input)
return gof.Apply(self, return gof.Apply(self,
......
...@@ -10,7 +10,7 @@ from copy import copy ...@@ -10,7 +10,7 @@ from copy import copy
from theano import compile from theano import compile
from theano import gradient from theano import gradient
from theano import gof from theano import gof
from theano.gof.python25 import any from theano.gof.python25 import any, all
from theano import gof from theano import gof
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
......
import traceback import traceback
import theano.tensor as T import theano.tensor as T
from ...gof import Env from theano.gof import Env
from ...printing import pp from theano.printing import pp
import numpy import numpy
from theano.tensor.blas import * from theano.tensor.blas import *
from theano.tensor.blas import _dot22, res_is_a from theano.tensor.blas import _dot22, res_is_a
...@@ -9,11 +9,8 @@ from unittest import TestCase ...@@ -9,11 +9,8 @@ from unittest import TestCase
from theano.tests import unittest_tools from theano.tests import unittest_tools
from copy import copy from copy import copy
_as_scalar = GemmLocalOptimizer._as_scalar
_is_real_matrix = GemmLocalOptimizer._is_real_matrix
from theano import In, Out from theano import In, Out
from .test_basic import (_approx_eq, as_tensor_variable, inplace_func, from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, value, constant, inplace, eval_outputs) compile, value, constant, inplace, eval_outputs)
class t_gemm(TestCase): class t_gemm(TestCase):
...@@ -415,4 +412,3 @@ def test_inplace1(): ...@@ -415,4 +412,3 @@ def test_inplace1():
# gemm should operate in-place on (Z+Z) # gemm should operate in-place on (Z+Z)
if (not gemm in [n.op for n in f.maker.env.nodes]): if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph') raise Failure('no gemm in graph')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论