提交 4cf7afb4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2952 from abergeron/flake8

Flake8 work
...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op): ...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs) - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op): ...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient. # not see them. Otherwise their is problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
used_inputs = [var for var in gof.graph.inputs(outputs)
if not isinstance(var, gof.Constant)]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars, new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, replace=dict(zip(self.shared_inputs,
...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op): ...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" raise TypeError("Wrong type, expected %s but got %s" %
% (type, input.type)) (type, input.type))
return gof.Apply(self, return gof.Apply(self,
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op): ...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops grad_ops = self.grad_ops
else: else:
gs = theano.gradient.grad(cost=None, gs = theano.gradient.grad(cost=None,
known_grads=dict(zip(self.new_outputs, output_grads)), known_grads=dict(zip(self.new_outputs,
output_grads)),
wrt=self.new_inputs, wrt=self.new_inputs,
disconnected_inputs='ignore') disconnected_inputs='ignore')
......
"""Define the `function` function """Define the `function` function
""" """
__docformat__ = "restructuredtext en"
import cPickle import cPickle
import logging import logging
_logger = logging.getLogger('theano.compile.function')
import traceback as tb import traceback as tb
import re import re
...@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function ...@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
from numpy import any from numpy import any
import warnings import warnings
from theano import gof
from theano import compat from theano import compat
__docformat__ = "restructuredtext en"
_logger = logging.getLogger('theano.compile.function')
def function_dump(filename, inputs, outputs=None, mode=None, updates=None, def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
givens=None, givens=None,
...@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `Mode` instance. :type mode: string or `Mode` instance.
:param mode: compilation mode :param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or OrderedDict. :type updates: iterable over pairs (shared_variable, new_expression).
:param updates: update the values for SharedVariable inputs according to these expressions List, tuple or OrderedDict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1 :type givens: iterable over pairs (Var1, Var2) of Variables. List,
and Var2 in each pair must have the same Type. tuple or dict. The Var1 and Var2 in each pair must
have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation
Var1). graph (Var2 replaces Var1).
:type no_default_updates: either bool or list of Variables :type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables. :param no_default_updates: if True, do not perform any automatic
If False (default), perform them all. Else, perform automatic updates on all Variables update on Variables. If False (default), perform them
that are neither in "updates" nor in "no_default_updates". all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
:param name: an optional name for this function. The profile mode
:param rebuild_strict: True (Default) is the safer and better tested setting, in which case will print the time spent in this function.
`givens` must substitute new variables with the same Type as the variables they replace.
False is a you-better-know-what-you-are-doing setting, that permits `givens` to replace :param rebuild_strict: True (Default) is the safer and better
variables with new variables of any Type. The consequence of changing a Type is that all tested setting, in which case `givens` must substitute new
results depending on that variable may have a different Type too (the graph is rebuilt from variables with the same Type as the variables they replace.
inputs to outputs). If one of the new types does not make sense for one of the Ops in the False is a you-better-know-what-you-are-doing setting, that
permits `givens` to replace variables with new variables of
any Type. The consequence of changing a Type is that all
results depending on that variable may have a different Type
too (the graph is rebuilt from inputs to outputs). If one of
the new types does not make sense for one of the Ops in the
graph, an Exception will be raised. graph, an Exception will be raised.
:type allow_input_downcast: Boolean or None :type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as :param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to
the dtype of the corresponding Variable, which may lose precision. fit the dtype of the corresponding Variable, which may lose
False means that it will only be cast to a more general, or precision. False means that it will only be cast to a more
precise, type. None (default) is almost like False, but allows general, or precise, type. None (default) is almost like
downcasting of Python float scalars to floatX. False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, or ProfileStats instance :type profile: None, True, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats :param profile: accumulate profiling information into a given
instance. If argument is `True` then a new ProfileStats instance will be ProfileStats instance. If argument is `True` then a new
used. This profiling object will be available via self.profile. ProfileStats instance will be used. This profiling object
will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the 'inputs'
not used in the graph. Possible values are 'raise', 'warn', 'ignore' and None. list is not used in the graph. Possible values are 'raise',
'warn', 'ignore' and None.
:rtype: Function instance :rtype: Function instance
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given
and update the implicit function arguments according to the `updates`. the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in substitutions are independent--behaviour when Var1 of one pair
another expression is undefined. Replacements specified with givens are different from appears in the graph leading to Var2 in another expression is
optimizations in that Var2 is not expected to be equivalent to Var1. undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
Internal documentation: Internal documentation:
...@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
was easier to develop the VM in Python then translate it to C instead was easier to develop the VM in Python then translate it to C instead
of just writing it in C from scratch. of just writing it in C from scratch.
CVM stands for C Virtual Machine. CVM stands for C Virtual Machine.
""" """
if isinstance(outputs, dict): if isinstance(outputs, dict):
output_items = outputs.items() output_items = outputs.items()
...@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
output_keys.append(pair[0]) output_keys.append(pair[0])
outputs.append(pair[1]) outputs.append(pair[1])
else: else:
output_keys = None output_keys = None
...@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if givens is None: if givens is None:
givens = [] givens = []
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
raise Exception("Input variables of a Theano function should be" raise Exception("Input variables of a Theano function should be "
" contained in a list, even when there is a single input.") "contained in a list, even when there is a single "
"input.")
# compute some features of the arguments: # compute some features of the arguments:
uses_In = any([isinstance(i, In) for i in inputs]) # N.B. the square brackets are ncessary uses_In = any([isinstance(i, In) for i in inputs])
uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs]) # N.B. the square brackets are ncessary uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs])
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
...@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if uses_In or uses_tuple: if uses_In or uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if profile: if profile:
raise NotImplementedError('profiling not supported in old-style function') raise NotImplementedError("profiling not supported in old-style "
"function")
if uses_updates or uses_givens: if uses_updates or uses_givens:
raise NotImplementedError( raise NotImplementedError(
"In() instances and tuple inputs trigger the old " "In() instances and tuple inputs trigger the old "
...@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
# note: pfunc will also call orig_function-- orig_function is a choke point # note: pfunc will also call orig_function-- orig_function is
# that all compilation must pass through # a choke point that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
mode=mode, mode=mode,
......
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """ """Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__ = 'restructuredtext en'
from theano import gof from theano import gof
from sharedvalue import SharedVariable from sharedvalue import SharedVariable
...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable ...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
import logging import logging
_logger = logging.getLogger("theano.compile.io") _logger = logging.getLogger("theano.compile.io")
__docformat__ = 'restructuredtext en'
class SymbolicInput(object): class SymbolicInput(object):
""" """
...@@ -17,34 +18,47 @@ class SymbolicInput(object): ...@@ -17,34 +18,47 @@ class SymbolicInput(object):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
mutable: Bool (default: False if update is None, True if update is not None) False: do not permit the compiled function to modify the
True: permit the compiled function to modify the python object being passed as the input python object being passed as the input.
False: do not permit the compiled function to modify the python object being passed as the input.
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
implicit: Bool (default: False) implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we are in the See help(In). Note that 'None' is not allowed here, since we
symbolic case. are in the symbolic case.
""" """
def __init__(self, variable, name=None, update=None, mutable=None, def __init__(self, variable, name=None, update=None, mutable=None,
...@@ -146,36 +160,54 @@ class In(SymbolicInput): ...@@ -146,36 +160,54 @@ class In(SymbolicInput):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
value: Any type. value: Any type.
The initial/default value for this input. If update is None, this input acts just like The initial/default value for this input. If update is None,
an argument with a default value in Python. If update is not None, changes to this this input acts just like an argument with a default value in
value will "stick around", whether due to an update or a user's explicit action. Python. If update is not None, changes to this value will
"stick around", whether due to an update or a user's explicit
action.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None) mutable: Bool (default: False if update is None, True if update is
True: permit the compiled function to modify the python object being passed as the input not None)
False: do not permit the compiled function to modify the python object being passed as the input.
True: permit the compiled function to modify the python object
being passed as the input
False: do not permit the compiled function to modify the
python object being passed as the input.
borrow: Bool (default: take the same value as mutable) borrow: Bool (default: take the same value as mutable)
True: permit the output of the compiled function to be aliased to the input
True: permit the output of the compiled function to be aliased
to the input
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -194,11 +226,11 @@ class In(SymbolicInput): ...@@ -194,11 +226,11 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt, # Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None,
implicit=None, borrow=None, shared=False): autoname=True, implicit=None, borrow=None, shared=False):
# if shared, an input's value comes from its persistent
# if shared, an input's value comes from its persistent storage, not from a default stored # storage, not from a default stored in the function or from
# in the function or from the caller # the caller
self.shared = shared self.shared = shared
if borrow is None: if borrow is None:
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
""" """
from __future__ import print_function from __future__ import print_function
import logging import logging
import warnings
from textwrap import dedent
import numpy import numpy
...@@ -11,24 +9,24 @@ import theano ...@@ -11,24 +9,24 @@ import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.compile.ops import register_view_op_c_code, _output_guard from theano.compile.ops import _output_guard
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
("When using the default mode, we will remove optimizer with these " ("When using the default mode, we will remove optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
("When using the default mode, we will add optimizer with these tags. " ("When using the default mode, we will add optimizer with "
"Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
("When using the default mode, we will require optimizer with these " ("When using the default mode, we will require optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
...@@ -50,9 +48,9 @@ def check_equal(x, y): ...@@ -50,9 +48,9 @@ def check_equal(x, y):
y = y.todense() y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if (x.dtype != y.dtype if (x.dtype != y.dtype or
or x.shape != y.shape x.shape != y.shape or
or numpy.any(abs(x - y) > 1e-10)): numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
else: else:
...@@ -287,7 +285,8 @@ class Mode(object): ...@@ -287,7 +285,8 @@ class Mode(object):
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__, return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
self.provided_linker, self.provided_optimizer) self.provided_linker,
self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -364,10 +363,11 @@ def get_mode(orig_string): ...@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker. # DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # This might be required if the string is 'ProfileMode'
from profilemode import ProfileMode, prof_mode_instance_to_print from profilemode import ProfileMode # noqa
ret = eval(string from profilemode import prof_mode_instance_to_print
+ '(linker=config.linker, optimizer=config.optimizer)') ret = eval(string +
'(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes: elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
......
from __future__ import print_function from __future__ import print_function
# Note: this code was initially copied from the 'pyutools' package by its # Note: this code was initially copied from the 'pyutools' package by its
# original author, and re-licensed under Theano's license. # original author, and re-licensed under Theano's license.
import numpy
import theano import theano
from theano.compile.mode import Mode from theano.compile.mode import Mode
......
...@@ -71,11 +71,12 @@ class ViewOp(gof.Op): ...@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for ViewOp, but it has " warnings.warn("Type %s has C code for ViewOp, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_view_op_c_code." % t, "arg when calling register_view_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op): ...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for DeepCopyOp, but it has " warnings.warn("Type %s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword"
"when calling register_deep_copy_op_c_code." % t, " arg when calling "
"register_deep_copy_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -284,11 +287,12 @@ class Shape(gof.Op): ...@@ -284,11 +287,12 @@ class Shape(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Shape, but it has " warnings.warn("Type %s has C code for Shape, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_shape_c_code." % t, "arg when calling register_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -301,7 +305,6 @@ class Shape(gof.Op): ...@@ -301,7 +305,6 @@ class Shape(gof.Op):
shape = Shape() shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape'))
class Shape_i(gof.Op): class Shape_i(gof.Op):
...@@ -389,8 +392,11 @@ class Shape_i(gof.Op): ...@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
return [()] return [()]
def grad(self, inp, grads): def grad(self, inp, grads):
return [theano.gradient.grad_not_implemented(op=self, x_pos=0, x=inp[0], return [theano.gradient.grad_not_implemented(
comment="No gradient for the shape of a matrix is implemented.")] op=self, x_pos=0, x=inp[0],
comment=("No gradient for the shape of a matrix "
"is implemented."))]
def shape_i(var, i, fgraph=None): def shape_i(var, i, fgraph=None):
"""Equivalent of var.shape[i], but apply if possible the shape """Equivalent of var.shape[i], but apply if possible the shape
...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None): ...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
def register_shape_i_c_code(typ, code, check_input, version=()): def register_shape_i_c_code(typ, code, check_input, version=()):
""" Tell Shape_i how to generate C code for a Theano Type """ Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and not
instance of the class. an instance of the class.
:param code: C code that gets the shape of dimensions %(i)s for the Theano type 'typ'. :param code: C code that gets the shape of dimensions %(i)s for the
Theano type 'typ'.
Use %(iname)s and %(oname)s for the input and output C Use %(iname)s and %(oname)s for the input and output C
variable names respectively. variable names respectively.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op): ...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
return type(self) == type(other) and self.axis == other.axis return type(self) == type(other) and self.axis == other.axis
def __hash__(self): def __hash__(self):
items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique # no ambiguity because each item key is unique
items = sorted(self.axis.iteritems())
return hash((type(self), tuple(items))) return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op): ...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())): if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast non-existent dimension') raise ValueError('Trying to rebroadcast non-existent dimension')
t = x.type.clone(broadcastable=[self.axis.get(i, b) t = x.type.clone(
for i, b in enumerate( broadcastable=[self.axis.get(i, b)
x.type.broadcastable)]) for i, b in enumerate(x.type.broadcastable)])
return gof.Apply(self, [x], [t()]) return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op): ...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
for t, (c, v) in sorted(self.c_code_and_version.items(), for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Rebroadcast, but it has " warnings.warn("Type %s has C code for Rebroadcast, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_rebroadcast_c_code." % t, "keyword arg when calling "
"register_rebroadcast_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(), ...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply=None): c_support_code_apply=None):
""" Tell SpecifyShape how to generate C code for a Theano Type """ Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and
instance of the class. not an instance of the class.
:param code: C code that checks the shape and returns a view for the Theano type 'typ'. :param code: C code that checks the shape and returns a view for
Use %(iname)s and %(oname)s for the input and output C the Theano type 'typ'. Use %(iname)s and %(oname)s
variable names respectively. for the input and output C variable names
%(shape)s is the vector of shape of %(iname)s. respectively. %(shape)s is the vector of shape of
Check that its length is good. %(iname)s. Check that its length is good.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code. :param c_support_code_apply: extra code.
""" """
SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) SpecifyShape.c_code_and_version[typ] = (code, version,
c_support_code_apply)
class SpecifyShape(gof.Op): class SpecifyShape(gof.Op):
...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op): ...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
new_shape = [] new_shape = []
for dim in xrange(node.inputs[0].ndim): for dim in xrange(node.inputs[0].ndim):
try: try:
s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim]) s = theano.tensor.get_scalar_constant_value(
node.inputs[1][dim])
s = theano.tensor.as_tensor_variable(s) s = theano.tensor.as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op): ...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
code, version, _ = self.c_code_and_version[itype] code, version, _ = self.c_code_and_version[itype]
return code % locals() return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub) return super(SpecifyShape, self).c_code(node, node, inames,
onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op): ...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
for t, (c, v, _) in sorted(self.c_code_and_version.items(), for t, (c, v, _) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for SpecifyShape, but it has " warnings.warn("Type %s has C code for SpecifyShape, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_specify_shape_c_code." % t, "keyword arg when calling "
"register_specify_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
......
差异被折叠。
"""Provide a simple user friendly API to Theano-managed memory""" """Provide a simple user friendly API to Theano-managed memory"""
__docformat__ = 'restructuredtext en'
# Standard imports # Standard imports
import copy import copy
import logging import logging
...@@ -12,6 +10,7 @@ import numpy ...@@ -12,6 +10,7 @@ import numpy
from theano.gof import Container, Variable, generic, utils from theano.gof import Container, Variable, generic, utils
_logger = logging.getLogger('theano.compile.sharedvalue') _logger = logging.getLogger('theano.compile.sharedvalue')
__docformat__ = 'restructuredtext en'
class SharedVariable(Variable): class SharedVariable(Variable):
...@@ -49,7 +48,8 @@ class SharedVariable(Variable): ...@@ -49,7 +48,8 @@ class SharedVariable(Variable):
or copied, so they must have the correct type. or copied, so they must have the correct type.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment. True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss. False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX. None -> only allow downcasting of a Python float to a scalar floatX.
...@@ -65,12 +65,13 @@ class SharedVariable(Variable): ...@@ -65,12 +65,13 @@ class SharedVariable(Variable):
if container is not None: if container is not None:
self.container = container self.container = container
if (value is not None) or (strict is not None): if (value is not None) or (strict is not None):
raise TypeError( raise TypeError('value and strict are ignored if you pass '
'value and strict are ignored if you pass a container here') 'a container here')
else: else:
if container is not None: if container is not None:
raise TypeError('Error to specify both value and container') raise TypeError('Error to specify both value and container')
self.container = Container(self, self.container = Container(
self,
storage=[type.filter(value, strict=strict, storage=[type.filter(value, strict=strict,
allow_downcast=allow_downcast)], allow_downcast=allow_downcast)],
readonly=False, readonly=False,
...@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs): ...@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
potential constructors to those that can accept those kwargs. potential constructors to those that can accept those kwargs.
:note: Some shared variable have ``borrow`` as extra kwargs. :note: Some shared variable have ``borrow`` as extra kwargs.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.html#borrowing-when-creating-shared-variables>`_ for detail. `See <http://deeplearning.net/software/theano/tutorial/aliasing.\
html#borrowing-when-creating-shared-variables>`_ for detail.
:note: Some shared variable have ``broadcastable`` as extra kwargs. :note: Some shared variable have ``broadcastable`` as extra kwargs.
As shared variable shapes can change, all dimensions default As shared variable shapes can change, all dimensions default
...@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs): ...@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
try: try:
if isinstance(value, Variable): if isinstance(value, Variable):
raise TypeError(" Shared variable constructor needs numeric values and not symbolic variables.") raise TypeError("Shared variable constructor needs numeric "
"values and not symbolic variables.")
for ctor in reversed(shared.constructors): for ctor in reversed(shared.constructors):
try: try:
......
import cPickle, logging import cPickle
import logging
_logger = logging.getLogger("theano.gof.callcache") _logger = logging.getLogger("theano.gof.callcache")
...@@ -18,9 +19,6 @@ class CallCache(object): ...@@ -18,9 +19,6 @@ class CallCache(object):
def persist(self, filename=None): def persist(self, filename=None):
if filename is None: if filename is None:
filename = self.filename filename = self.filename
# backport
#filename = self.filename if filename is None else filename
f = open(filename, 'w') f = open(filename, 'w')
cPickle.dump(self.cache, f) cPickle.dump(self.cache, f)
f.close() f.close()
...@@ -28,9 +26,6 @@ class CallCache(object): ...@@ -28,9 +26,6 @@ class CallCache(object):
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
if key is None: if key is None:
key = (fn, tuple(args)) key = (fn, tuple(args))
# backport
#key = (fn, tuple(args)) if key is None else key
if key not in self.cache: if key not in self.cache:
_logger.debug('cache miss %i', len(self.cache)) _logger.debug('cache miss %i', len(self.cache))
self.cache[key] = fn(*args) self.cache[key] = fn(*args)
......
...@@ -8,7 +8,6 @@ import re ...@@ -8,7 +8,6 @@ import re
import shutil import shutil
import struct import struct
import socket import socket
import subprocess
import sys import sys
import textwrap import textwrap
...@@ -295,7 +294,8 @@ def cleanup(): ...@@ -295,7 +294,8 @@ def cleanup():
have_npy_abi_version = True have_npy_abi_version = True
elif obj.startswith('c_compiler_str='): elif obj.startswith('c_compiler_str='):
have_c_compiler = True have_c_compiler = True
elif (isinstance(obj, (theano.gof.Op, theano.gof.Type)) and elif (isinstance(obj, (theano.gof.Op,
theano.gof.Type)) and
hasattr(obj, 'c_code_cache_version')): hasattr(obj, 'c_code_cache_version')):
v = obj.c_code_cache_version() v = obj.c_code_cache_version()
if v not in [(), None] and v not in key[0]: if v not in [(), None] and v not in key[0]:
...@@ -310,7 +310,7 @@ def cleanup(): ...@@ -310,7 +310,7 @@ def cleanup():
if keydata.key_pkl != filename: if keydata.key_pkl != filename:
keydata.key_pkl = filename keydata.key_pkl = filename
keydata.remove_key(key) keydata.remove_key(key)
except IOError as e: except IOError:
_logger.error( _logger.error(
"Could not remove file '%s'. To complete " "Could not remove file '%s'. To complete "
"the clean-up, please remove manually " "the clean-up, please remove manually "
...@@ -395,7 +395,7 @@ def print_compiledir_content(): ...@@ -395,7 +395,7 @@ def print_compiledir_content():
if big_key_files: if big_key_files:
big_key_files = sorted(big_key_files, key=lambda t: str(t[1])) big_key_files = sorted(big_key_files, key=lambda t: str(t[1]))
big_total_size = sum([size for dir, size, ops in big_key_files]) big_total_size = sum([sz for _, sz, _ in big_key_files])
print(("There are directories with key files bigger than %d bytes " print(("There are directories with key files bigger than %d bytes "
"(they probably contain big tensor constants)" % "(they probably contain big tensor constants)" %
max_key_file_size)) max_key_file_size))
......
...@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw): ...@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw):
# the lock state and raise an error. # the lock state and raise an error.
while get_lock.n_lock > 0: while get_lock.n_lock > 0:
release_lock() release_lock()
raise Exception("For some unknow reason, the lock was already taken," raise Exception("For some unknow reason, the lock was already "
" but no start time was registered.") "taken, but no start time was registered.")
now = time.time() now = time.time()
if now - get_lock.start_time > config.compile.timeout/2: if now - get_lock.start_time > config.compile.timeout/2:
lockpath = os.path.join(get_lock.lock_dir, 'lock') lockpath = os.path.join(get_lock.lock_dir, 'lock')
......
...@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils_code(): def compile_cutils_code():
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128', types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'int256', 'uint8', 'uint16', 'uint32',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64',
'float80', 'float96', 'float128',
'float256']] 'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64', complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']] 'complex128', 'complex160',
'complex192', 'complex512']]
inplace_map_template = """ inplace_map_template = """
#if defined(%(typen)s) #if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set) static void %(type)s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{ {
int index = mit->size; int index = mit->size;
while (index--) { while (index--) {
...@@ -38,10 +41,13 @@ def compile_cutils_code(): ...@@ -38,10 +41,13 @@ def compile_cutils_code():
#endif #endif
""" """
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];" floatadd = ("((%(type)s*)mit->dataptr)[0] = inc_or_set * "
"((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];")
complexadd = """ complexadd = """
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real; ((%(type)s*)mit->dataptr)[0].real = inc_or_set *
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag; ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set *
((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
""" """
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(), fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
...@@ -51,33 +57,36 @@ def compile_cutils_code(): ...@@ -51,33 +57,36 @@ def compile_cutils_code():
'op': complexadd % {'type': t}} 'op': complexadd % {'type': t}}
for t in complex_types]) for t in complex_types])
def gen_binop(type, typen):
return """
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % dict(type=type, typen=typen)
fn_array = ("static inplace_map_binop addition_funcs[] = {" + fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join([""" ''.join([gen_binop(type=t, typen=t.upper())
#if defined(%(typen)s) for t in types + complex_types]) + "NULL};\n")
%(type)s_inplace_add,
#endif def gen_num(typen):
""" % {'type': t, 'typen': t.upper()} return """
for t in types + complex_types]) + #if defined(%(typen)s)
"""NULL}; %(typen)s,
""") #endif
""" % dict(type=type, typen=typen)
type_number_array = ("static int type_numbers[] = {" + type_number_array = ("static int type_numbers[] = {" +
''.join([""" ''.join([gen_num(typen=t.upper())
#if defined(%(typen)s) for t in types + complex_types]) + "-1000};")
%(typen)s,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
code = (""" code = ("""
#if NPY_API_VERSION >= 0x00000008 #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set); typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
""" + fns + fn_array + type_number_array + PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array + """
"""
static int static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set) map_increment(PyArrayMapIterObject *mit, PyObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{ {
PyArrayObject *arr = NULL; PyArrayObject *arr = NULL;
PyArrayIterObject *it; PyArrayIterObject *it;
...@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL; return NULL;
} }
if (!PyArray_Check(arg_a)) { if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument"); PyErr_SetString(PyExc_ValueError,
"needs an ndarray as first argument");
return NULL; return NULL;
} }
...@@ -285,7 +295,7 @@ try: ...@@ -285,7 +295,7 @@ try:
open(os.path.join(location, '__init__.py'), 'w').close() open(os.path.join(location, '__init__.py'), 'w').close()
try: try:
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
except ImportError: except ImportError:
get_lock() get_lock()
# Ensure no-one else is currently modifying the content of the compilation # Ensure no-one else is currently modifying the content of the compilation
...@@ -296,11 +306,11 @@ try: ...@@ -296,11 +306,11 @@ try:
# We must retry to import it as some other process could # We must retry to import it as some other process could
# have been compiling it between the first failed import # have been compiling it between the first failed import
# and when we receive the lock # and when we receive the lock
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
except ImportError: except ImportError:
compile_cutils() compile_cutils()
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
finally: finally:
# Release lock on compilation directory. # Release lock on compilation directory.
......
...@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c') ...@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile = False force_compile = False
version = 0.21 # must match constant returned in function get_version() version = 0.21 # must match constant returned in function get_version()
lazylinker_ext = None
def try_import(): def try_import():
global lazylinker_ext global lazylinker_ext
sys.path[0:0] = [config.compiledir] sys.path[0:0] = [config.compiledir]
import lazylinker_ext import lazylinker_ext # noqa
del sys.path[0] del sys.path[0]
...@@ -43,11 +44,11 @@ try: ...@@ -43,11 +44,11 @@ try:
# Try to make the location # Try to make the location
os.mkdir(location) os.mkdir(location)
except OSError as e: except OSError as e:
# If we get an error, verify that the error was # 17, the path already exists, # If we get an error, verify that the error was # 17, the
# and that it is a directory # path already exists, and that it is a directory Note: we
# Note: we can't check if it exists before making it, because we are not holding # can't check if it exists before making it, because we
# the lock right now, so we could race another process and get error 17 if we lose # are not holding the lock right now, so we could race
# the race # another process and get error 17 if we lose the race
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.isdir(location) assert os.path.isdir(location)
...@@ -142,5 +143,5 @@ except ImportError: ...@@ -142,5 +143,5 @@ except ImportError:
# Release lock on compilation directory. # Release lock on compilation directory.
release_lock() release_lock()
from lazylinker_ext.lazylinker_ext import * from lazylinker_ext.lazylinker_ext import * # noqa
assert force_compile or (version == get_version()) assert force_compile or (version == get_version())
...@@ -32,7 +32,7 @@ class DB(object): ...@@ -32,7 +32,7 @@ class DB(object):
self.__db__ = DefaultOrderedDict(OrderedSet) self.__db__ = DefaultOrderedDict(OrderedSet)
self._names = set() self._names = set()
self.name = None # will be reset by register self.name = None # will be reset by register
#(via obj.name by the thing doing the registering) # (via obj.name by the thing doing the registering)
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
""" """
...@@ -175,8 +175,10 @@ class Query(object): ...@@ -175,8 +175,10 @@ class Query(object):
self.exclude = OrderedSet(self.exclude) self.exclude = OrderedSet(self.exclude)
def __str__(self): def __str__(self):
return "Query{inc=%s,ex=%s,require=%s,subquery=%s,position_cutoff=%d}" % ( return ("Query{inc=%s,ex=%s,require=%s,subquery=%s,"
self.include, self.exclude, self.require, self.subquery, self.position_cutoff) "position_cutoff=%d}" %
(self.include, self.exclude, self.require, self.subquery,
self.position_cutoff))
# add all opt with this tag # add all opt with this tag
def including(self, *tags): def including(self, *tags):
...@@ -268,7 +270,7 @@ class SequenceDB(DB): ...@@ -268,7 +270,7 @@ class SequenceDB(DB):
position_cutoff = kwtags.pop('position_cutoff', position_cutoff = kwtags.pop('position_cutoff',
config.optdb.position_cutoff) config.optdb.position_cutoff)
if len(tags) >= 1 and isinstance(tags[0], Query): if len(tags) >= 1 and isinstance(tags[0], Query):
# the call to super should have raise an error with a good message # the call to super should have raise an error with a good message
assert len(tags) == 1 assert len(tags) == 1
if getattr(tags[0], 'position_cutoff', None): if getattr(tags[0], 'position_cutoff', None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
......
...@@ -39,8 +39,8 @@ def make_depends(): ...@@ -39,8 +39,8 @@ def make_depends():
def depends(pair): def depends(pair):
""" Returns True if a depends on b """ """ Returns True if a depends on b """
a, b = pair a, b = pair
return (any(bout in a.inputs for bout in b.outputs) return (any(bout in a.inputs for bout in b.outputs) or
or any(depends((ainp.owner, b)) for ainp in a.inputs any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner)) if ainp.owner))
return depends return depends
...@@ -160,12 +160,12 @@ def posort(l, *cmps): ...@@ -160,12 +160,12 @@ def posort(l, *cmps):
for b in l: for b in l:
assert not(b in comes_after[a] and a in comes_after[b]) assert not(b in comes_after[a] and a in comes_after[b])
for cmp in cmps: for cmp_fn in cmps:
for a in l: for a in l:
for b in l: for b in l:
if cmp(a, b) < 0: # a wants to come before b if cmp_fn(a, b) < 0: # a wants to come before b
# if this wouldn't cause a cycle and isn't already known # if this wouldn't cause a cycle and isn't already known
if not b in comes_before[a] and not b in comes_after[a]: if b not in comes_before[a] and b not in comes_after[a]:
add_links(a, b) add_links(a, b)
# check() # debug code # check() # debug code
......
...@@ -36,8 +36,11 @@ def test_give_variables_names_small(): ...@@ -36,8 +36,11 @@ def test_give_variables_names_small():
def test_remove(): def test_remove():
even = lambda x: x % 2 == 0 def even(x):
odd = lambda x: x % 2 == 1 return x % 2 == 0
# The list are neede as with python 3, remove and filter return generators
def odd(x):
return x % 2 == 1
# The list are needed as with python 3, remove and filter return generators
# and we can't compare generators. # and we can't compare generators.
assert list(remove(even, range(5))) == list(filter(odd, range(5))) assert list(remove(even, range(5))) == list(filter(odd, range(5)))
...@@ -214,9 +214,9 @@ class Validator(Feature): ...@@ -214,9 +214,9 @@ class Validator(Feature):
class ReplaceValidate(History, Validator): class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate", pickle_rm_attr = (["replace_validate", "replace_all_validate",
"replace_all_validate_remove"] + \ "replace_all_validate_remove"] +
History.pickle_rm_attr + Validator.pickle_rm_attr History.pickle_rm_attr + Validator.pickle_rm_attr)
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('replace_validate', 'replace_all_validate', for attr in ('replace_validate', 'replace_all_validate',
...@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator): ...@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator):
try: try:
fgraph.replace(r, new_r, reason=reason, verbose=False) fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception as e: except Exception as e:
if ('The type of the replacement must be the same' not in msg = str(e)
str(e) and 'does not belong to this FunctionGraph' not in str(e)): s1 = 'The type of the replacement must be the same'
s2 = 'does not belong to this FunctionGraph'
if (s1 not in msg and s2 not in msg):
out = sys.stderr out = sys.stderr
print("<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>", end=' ', file=out) print("<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>",
print(type(e), e, reason, file=out) type(e), e, reason, file=out)
# this might fail if the error is in a listener: # this might fail if the error is in a listener:
# (fgraph.replace kinda needs better internal error handling) # (fgraph.replace kinda needs better internal error handling)
fgraph.revert(chk) fgraph.revert(chk)
...@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator): ...@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator):
fgraph.revert(chk) fgraph.revert(chk)
if warn: if warn:
out = sys.stderr out = sys.stderr
print(( print(
"WARNING: An optimization wanted to replace a Variable" "WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't" " in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization." " remove it. We disabled the optimization."
" Your function runs correctly, but it would be" " Your function runs correctly, but it would be"
" appreciated if you submit this problem to the" " appreciated if you submit this problem to the"
" mailing list theano-users so that we can fix it."), file=out) " mailing list theano-users so that we can fix it.",
file=out)
print(reason, replacements, file=out) print(reason, replacements, file=out)
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
...@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper): ...@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper):
def on_attach(self, fgraph): def on_attach(self, fgraph):
if self.fgraph is not None: if self.fgraph is not None:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.") raise Exception("A NodeFinder instance can only serve one "
"FunctionGraph.")
if hasattr(fgraph, 'get_nodes'): if hasattr(fgraph, 'get_nodes'):
raise AlreadyThere("NodeFinder is already present or in conflict" raise AlreadyThere("NodeFinder is already present or in conflict"
" with another plugin.") " with another plugin.")
......
"""WRITEME Defines the `Type` class.""" """WRITEME Defines the `Type` class."""
__docformat__ = "restructuredtext en"
from theano.compat import PY3 from theano.compat import PY3
from theano.gof import utils from theano.gof import utils
...@@ -13,6 +10,8 @@ from theano.gof import graph ...@@ -13,6 +10,8 @@ from theano.gof import graph
######## ########
from theano.gof.op import CLinkerObject from theano.gof.op import CLinkerObject
__docformat__ = "restructuredtext en"
class CLinkerType(CLinkerObject): class CLinkerType(CLinkerObject):
"""Interface specification for Types that can be arguments to a `CLinkerOp`. """Interface specification for Types that can be arguments to a `CLinkerOp`.
...@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject): ...@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise MethodNotDefined("c_literal", type(self), self.__class__.__name__) raise MethodNotDefined("c_literal", type(self),
self.__class__.__name__)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
"""Required: Return c code to declare variables that will be """Required: Return c code to declare variables that will be
...@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject): ...@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject):
return "PyObject ** addr_of_%(name)s;" return "PyObject ** addr_of_%(name)s;"
:param name: the name of the ``PyObject *`` pointer that will the value for this Type :param name: the name of the ``PyObject *`` pointer that will
the value for this Type
:type name: string :type name: string
...@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject): ...@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise MethodNotDefined("c_extract", type(self), self.__class__.__name__) raise MethodNotDefined("c_extract", type(self),
self.__class__.__name__)
def c_extract_out(self, name, sub, check_input=True): def c_extract_out(self, name, sub, check_input=True):
"""Optional: C code to extract a PyObject * instance. """Optional: C code to extract a PyObject * instance.
...@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject): ...@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject):
def c_sync(self, name, sub): def c_sync(self, name, sub):
"""Required: Return c code to pack C types back into a PyObject. """Required: Return c code to pack C types back into a PyObject.
The code returned from this function must be templated using "%(name)s", The code returned from this function must be templated using
representing the name that the caller wants to call this Variable. The "%(name)s", representing the name that the caller wants to
returned code may set "py_%(name)s" to a PyObject* and that PyObject* call this Variable. The returned code may set "py_%(name)s"
will be accessible from Python via variable.data. Do not forget to adjust to a PyObject* and that PyObject* will be accessible from
reference counts if "py_%(name)s" is changed from its original value. Python via variable.data. Do not forget to adjust reference
counts if "py_%(name)s" is changed from its original value.
:Parameters: :Parameters:
- `name`: WRITEME - `name`: WRITEME
...@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject): ...@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject):
def c_code_cache_version(self): def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Type. """Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes. An empty tuple indicates an 'unversioned' Type that will not
be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer The cache mechanism may erase cached modules that have been
versions. See `ModuleCache` for details. superceded by newer versions. See `ModuleCache` for details.
""" """
return () return ()
...@@ -221,19 +225,21 @@ class PureType(object): ...@@ -221,19 +225,21 @@ class PureType(object):
- creating `Variable` instances (conventionally, `__call__` does this), and - creating `Variable` instances (conventionally, `__call__` does this), and
- filtering a value assigned to a `Variable` so that the value conforms to restrictions - filtering a value assigned to a `Variable` so that the value
imposed by the type (also known as casting, this is done by `filter`), conforms to restrictions imposed by the type (also known as
casting, this is done by `filter`),
""" """
# the type that will be created by call to make_variable.
Variable = graph.Variable
Variable = graph.Variable # the type that will be created by call to make_variable. # the type that will be created by call to make_constant
Constant = graph.Constant # the type that will be created by call to make_constant Constant = graph.Constant
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
"""Required: Return data or an appropriately wrapped/converted data. """Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if the data is not of an Subclass implementation should raise a TypeError exception if
acceptable type. the data is not of an acceptable type.
If strict is True, the data returned must be the same as the If strict is True, the data returned must be the same as the
data passed as an argument. If it is False, and allow_downcast data passed as an argument. If it is False, and allow_downcast
...@@ -283,7 +289,8 @@ class PureType(object): ...@@ -283,7 +289,8 @@ class PureType(object):
return other return other
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type""" """Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type"""
try: try:
self.filter(a, strict=True) self.filter(a, strict=True)
return True return True
...@@ -291,7 +298,8 @@ class PureType(object): ...@@ -291,7 +298,8 @@ class PureType(object):
return False return False
def value_validity_msg(self, a): def value_validity_msg(self, a):
"""Optional: return a message explaining the output of is_valid_value""" """Optional: return a message explaining the output of
is_valid_value"""
return "none" return "none"
def make_variable(self, name=None): def make_variable(self, name=None):
...@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType): ...@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType):
But you are encouraged to write your own, as described in WRITEME. But you are encouraged to write your own, as described in WRITEME.
The following following code illustrates the use of a Type instance, here tensor.fvector: The following following code illustrates the use of a Type
instance, here tensor.fvector:
.. code-block:: python .. code-block:: python
...@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType): ...@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType):
# Create a second Variable with the same Type instance # Create a second Variable with the same Type instance
c = tensor.fvector() c = tensor.fvector()
Whenever you create a symbolic variable in theano (technically, `Variable`) it will contain a Whenever you create a symbolic variable in theano (technically,
reference to a Type instance. That reference is typically constant during the lifetime of `Variable`) it will contain a reference to a Type instance. That
the Variable. Many variables can refer to a single Type instance, as do b and c above. The reference is typically constant during the lifetime of the
Type instance defines the kind of value which might end up in that variable when executing Variable. Many variables can refer to a single Type instance, as
a `Function`. In this sense, theano is like a strongly-typed language because the types do b and c above. The Type instance defines the kind of value
are included in the graph before the values. In our example above, b is a Variable which is which might end up in that variable when executing a `Function`.
guaranteed to correspond to a numpy.ndarray of rank 1 when we try to do some computations In this sense, theano is like a strongly-typed language because
the types are included in the graph before the values. In our
example above, b is a Variable which is guaranteed to correspond
to a numpy.ndarray of rank 1 when we try to do some computations
with it. with it.
Many `Op` instances will raise an exception if they are applied to inputs with incorrect Many `Op` instances will raise an exception if they are applied to
types. Type references are also useful to do type-checking in pattern-based optimizations. inputs with incorrect types. Type references are also useful to
do type-checking in pattern-based optimizations.
""" """
def convert_variable(self, var): def convert_variable(self, var):
...@@ -451,8 +464,8 @@ class Generic(SingletonType): ...@@ -451,8 +464,8 @@ class Generic(SingletonType):
""" """
Represents a generic Python object. Represents a generic Python object.
This class implements the `PureType` and `CLinkerType` interfaces for generic PyObject This class implements the `PureType` and `CLinkerType` interfaces
instances. for generic PyObject instances.
EXAMPLE of what this means, or when you would use this type. EXAMPLE of what this means, or when you would use this type.
......
from __future__ import print_function from __future__ import print_function
import linecache import linecache
import traceback import traceback
import re
import sys import sys
from theano import config from theano import config
...@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None): ...@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None):
This is because this update cause an call to os.stat to get the This is because this update cause an call to os.stat to get the
line content. This cause too much long on cluster. line content. This cause too much long on cluster.
""" """
if f is None: if f is None:
try: try:
...@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4): ...@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4):
# I enable my implementation only for some python version just to # I enable my implementation only for some python version just to
# be sure the Python internal do not change. If this work with # be sure the Python internal do not change. If this work with
# other python version, you can enable it. # other python version, you can enable it.
simple_extract_stack = traceback.extract_stack simple_extract_stack = traceback.extract_stack # noqa
def add_tag_trace(thing, user_line=1): def add_tag_trace(thing, user_line=1):
...@@ -190,8 +188,8 @@ def deprecated(filename, msg=''): ...@@ -190,8 +188,8 @@ def deprecated(filename, msg=''):
def g(*args, **kwargs): def g(*args, **kwargs):
if printme[0]: if printme[0]:
print('WARNING: %s.%s deprecated. %s'\ print('WARNING: %s.%s deprecated. %s' %
% (filename, f.__name__, msg)) (filename, f.__name__, msg))
printme[0] = False printme[0] = False
return f(*args, **kwargs) return f(*args, **kwargs)
return g return g
...@@ -220,7 +218,7 @@ def difference(seq1, seq2): ...@@ -220,7 +218,7 @@ def difference(seq1, seq2):
raise Exception('not worth it') raise Exception('not worth it')
set2 = set(seq2) set2 = set(seq2)
return [x for x in seq1 if x not in set2] return [x for x in seq1 if x not in set2]
except Exception as e: except Exception:
# maybe a seq2 element is not hashable # maybe a seq2 element is not hashable
# maybe seq2 is too short # maybe seq2 is too short
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
...@@ -311,11 +309,11 @@ def comm_guard(type1, type2): ...@@ -311,11 +309,11 @@ def comm_guard(type1, type2):
old_f = f.func_globals[f.__name__] old_f = f.func_globals[f.__name__]
def new_f(arg1, arg2, *rest): def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) \ if ((type1 is ANY_TYPE or isinstance(arg1, type1)) and
and (type2 is ANY_TYPE or isinstance(arg2, type2)): (type2 is ANY_TYPE or isinstance(arg2, type2))):
pass pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) \ elif ((type1 is ANY_TYPE or isinstance(arg2, type1)) and
and (type2 is ANY_TYPE or isinstance(arg1, type2)): (type2 is ANY_TYPE or isinstance(arg1, type2))):
arg1, arg2 = arg2, arg1 arg1, arg2 = arg2, arg1
else: else:
return old_f(arg1, arg2, *rest) return old_f(arg1, arg2, *rest)
...@@ -337,7 +335,8 @@ def comm_guard(type1, type2): ...@@ -337,7 +335,8 @@ def comm_guard(type1, type2):
return type.__name__ return type.__name__
new_f.__doc__ = (str(old_f.__doc__) + "\n" + new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1, type2)]) + ", ".join([typename(type)
for type in (type1, type2)]) +
"\n" + str(f.__doc__ or "")) "\n" + str(f.__doc__ or ""))
return new_f return new_f
...@@ -406,15 +405,16 @@ def give_variables_names(variables): ...@@ -406,15 +405,16 @@ def give_variables_names(variables):
This function is idempotent.""" This function is idempotent."""
names = map(lambda var: var.name, variables) names = map(lambda var: var.name, variables)
h = hist(names) h = hist(names)
bad_var = lambda var: not var.name or h[var.name] > 1
def bad_var(var):
return not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)): for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d" % i var.name = (var.name or "") + "_%d" % i
if not unique(map(str, variables)): if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names." raise ValueError("Not all variables have unique names. Maybe you've "
"Maybe you've named some of the variables identically") "named some of the variables identically")
return variables return variables
......
...@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy', ...@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy',
in_c_key=False) in_c_key=False)
def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, dependencies): def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
dependencies):
reallocated_info = {} reallocated_info = {}
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
...@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
assert len( assert len(idx_v) == 1, ("Here we only support the possibility"
idx_v) == 1, "Here we only support the possibility to destroy one input" " to destroy one input")
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap: if vmap and idx_o in vmap:
assert ins is None assert ins is None
idx_v = vmap[idx_o] idx_v = vmap[idx_o]
assert len( assert len(idx_v) == 1, ("Here we only support the possibility"
idx_v) == 1, "Here we only support the possibility to view one input" " to view one input")
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if ins is not None: if ins is not None:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
...@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] and
and ins not in fgraph.outputs and ins.owner ins not in fgraph.outputs and ins.owner and
and all([compute_map_re[v][0] for v in dependencies.get(ins, [])]) all([compute_map_re[v][0]
and ins not in allocated): for v in dependencies.get(ins, [])]) and
ins not in allocated):
# Constant Memory cannot be changed # Constant Memory cannot be changed
# Constant and shared variables' storage_map value is not empty # Constant and shared variables' storage_map value is not empty
reuse_out = None reuse_out = None
...@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated if (getattr(out, 'ndim', None) == 0 and
and ins.type == out.type): out not in pre_allocated and
ins.type == out.type):
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
...@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated if (getattr(out, 'ndim', None) == 0 and
and ins.type == out.type): out not in pre_allocated and
ins.type == out.type):
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
...@@ -508,7 +512,8 @@ class Stack(VM): ...@@ -508,7 +512,8 @@ class Stack(VM):
st = "c" st = "c"
self.variable_strides[var] = st self.variable_strides[var] = st
except Exception: except Exception:
link.raise_with_op(current_apply, link.raise_with_op(
current_apply,
self.thunks[self.node_idx[current_apply]], self.thunks[self.node_idx[current_apply]],
storage_map=storage_map) storage_map=storage_map)
for o in current_apply.outputs: for o in current_apply.outputs:
...@@ -521,9 +526,9 @@ class Stack(VM): ...@@ -521,9 +526,9 @@ class Stack(VM):
for i in current_apply.inputs: for i in current_apply.inputs:
# Garbage Collection -> check if anybody else uses # Garbage Collection -> check if anybody else uses
# this input # this input
if (dependencies[i] if (dependencies[i] and
and i.owner i.owner and
and i not in self.outputs): i not in self.outputs):
if all(compute_map[v][0] if all(compute_map[v][0]
for v in dependencies[i]): for v in dependencies[i]):
storage_map[i][0] = None storage_map[i][0] = None
...@@ -544,10 +549,13 @@ class Stack(VM): ...@@ -544,10 +549,13 @@ class Stack(VM):
'destroy_map', 'destroy_map',
False)): False)):
warnings.warn( warnings.warn(
"There was a bug that existed in the default Theano configuration," "There was a bug that existed in "
" only in the development version between July 5th 2012" "the default Theano configuration,"
" and July 30th 2012. This was not in a released version." " only in the development version "
" The bug was affecting this script.", "between July 5th 2012 and "
"July 30th 2012. This was not in "
"a released version. The bug was "
"affecting this script.",
# The stack level is not good when # The stack level is not good when
# inside a Scan. # inside a Scan.
stacklevel=3 stacklevel=3
...@@ -578,7 +586,8 @@ class Stack(VM): ...@@ -578,7 +586,8 @@ class Stack(VM):
self.call_times[current_idx] += dt self.call_times[current_idx] += dt
except Exception: except Exception:
link.raise_with_op(current_apply, link.raise_with_op(
current_apply,
self.thunks[self.node_idx[current_apply]], self.thunks[self.node_idx[current_apply]],
storage_map=storage_map) storage_map=storage_map)
...@@ -639,7 +648,7 @@ class Stack(VM): ...@@ -639,7 +648,7 @@ class Stack(VM):
if self.allow_gc: if self.allow_gc:
for v in storage_map: for v in storage_map:
if v.owner and not v in self.outputs: if v.owner and v not in self.outputs:
if compute_map[v][0] == 2: if compute_map[v][0] == 2:
continue continue
else: else:
...@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker): ...@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker):
vars_idx_inv[i] = var vars_idx_inv[i] = var
# put storage_map and compute_map into a int-based scheme # put storage_map and compute_map into a int-based scheme
n_applies = len(nodes)
storage_map_list = [storage_map[vars_idx_inv[i]] storage_map_list = [storage_map[vars_idx_inv[i]]
for i in xrange(len(vars_idx_inv))] for i in xrange(len(vars_idx_inv))]
compute_map_list = [compute_map[vars_idx_inv[i]] compute_map_list = [compute_map[vars_idx_inv[i]]
...@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker): ...@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker):
else: else:
dependencies = self.compute_gc_dependencies(storage_map) dependencies = self.compute_gc_dependencies(storage_map)
reallocated_info = calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,dependencies) reallocated_info = calculate_reallocate_info(
order, fgraph, storage_map, compute_map_re, dependencies)
for node in order: for node in order:
try: try:
...@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker):
lazy = config.vm.lazy lazy = config.vm.lazy
if lazy is None: if lazy is None:
lazy = not all([(not th.lazy) for th in thunks]) lazy = not all([(not th.lazy) for th in thunks])
if not (lazy or (config.profile and config.profile_memory) or self.use_cloop or self.callback): if not (lazy or (config.profile and config.profile_memory) or
self.use_cloop or self.callback):
for pair in reallocated_info.values(): for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]] storage_map[pair[1]] = storage_map[pair[0]]
...@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker): ...@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker):
for node in order: for node in order:
clear_after_this_thunk = [] clear_after_this_thunk = []
for input in node.inputs: for input in node.inputs:
if ((input in computed) if (input in computed and
and (input not in fgraph.outputs) input not in fgraph.outputs and
and (node == last_user[input]) node == last_user[input] and
and input not in reallocated_info.keys()): input not in reallocated_info.keys()):
clear_after_this_thunk.append(storage_map[input]) clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk) post_thunk_clear.append(clear_after_this_thunk)
else: else:
......
...@@ -2,7 +2,6 @@ from functools import wraps ...@@ -2,7 +2,6 @@ from functools import wraps
import numpy import numpy
import theano
from theano import scalar as scal, Constant from theano import scalar as scal, Constant
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
......
...@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt ...@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt
# Skip tests if cuda_ndarray is not available. # Skip tests if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
if not cuda_ndarray.cuda_available: if not cuda_ndarray.cuda_available: # noqa
raise SkipTest('Optional package cuda not available') raise SkipTest('Optional package cuda not available')
from theano.misc.pycuda_init import pycuda_available from theano.misc.pycuda_init import pycuda_available
if not pycuda_available: if not pycuda_available: # noqa
raise SkipTest('Optional package pycuda not available') raise SkipTest('Optional package pycuda not available')
from theano.sandbox.cuda.fftconv import scikits_cuda_available from theano.sandbox.cuda.fftconv import scikits_cuda_available
if not scikits_cuda_available: if not scikits_cuda_available: # noqa
raise SkipTest('Optional package scikits.cuda not available') raise SkipTest('Optional package scikits.cuda not available')
from theano.sandbox.cuda import float32_shared_constructor as shared from theano.sandbox.cuda import float32_shared_constructor as shared
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# mpiexec -np 2 python _test_mpi_roundtrip.py # mpiexec -np 2 python _test_mpi_roundtrip.py
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD
import theano import theano
from theano.tensor.io import send, recv, mpi_cmps from theano.tensor.io import send, recv, mpi_cmps
from theano.gof.sched import sort_schedule_fn from theano.gof.sched import sort_schedule_fn
import numpy as np import numpy as np
from sys import stdout, stderr, exit from sys import stdout, stderr, exit
comm = MPI.COMM_WORLD
rank = comm.Get_rank() rank = comm.Get_rank()
size = comm.Get_size() size = comm.Get_size()
......
from datetime import datetime
__authors__ = "Ian Goodfellow" __authors__ = "Ian Goodfellow"
__credits__ = ["Ian Goodfellow"] __credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD" __license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow" __maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro" __email__ = "goodfeli@iro"
from datetime import datetime
def disturb_mem(): def disturb_mem():
# Allocate a time-dependent amount of objects to increase # Allocate a time-dependent amount of objects to increase
......
from __future__ import print_function from __future__ import print_function
import os, unittest, sys import os
import nose.plugins.builtin import unittest
import sys
from nose.config import Config from nose.config import Config
from nose.plugins.manager import PluginManager from nose.plugins.manager import PluginManager
from numpy.testing.nosetester import import_nose, NoseTester import nose.plugins.builtin
from numpy.testing.nosetester import NoseTester
from numpy.testing.noseclasses import KnownFailure, NumpyTestProgram from numpy.testing.noseclasses import KnownFailure, NumpyTestProgram
...@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester): ...@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester):
:type extra_argv: list :type extra_argv: list
:param extra_argv: List with any extra arguments to pass to nosetests. :param extra_argv: List with any extra arguments to pass to nosetests.
""" """
#self.package_path = os.path.abspath(self.package_path) # self.package_path = os.path.abspath(self.package_path)
argv = [__file__, self.package_path] argv = [__file__, self.package_path]
argv += ['--verbosity', str(verbose)] argv += ['--verbosity', str(verbose)]
if extra_argv: if extra_argv:
...@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester): ...@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester):
return argv return argv
def _show_system_info(self): def _show_system_info(self):
nose = import_nose()
import theano import theano
print("Theano version %s" % theano.__version__) print("Theano version %s" % theano.__version__)
theano_dir = os.path.dirname(theano.__file__) theano_dir = os.path.dirname(theano.__file__)
...@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester): ...@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester):
Takes the same arguments as `test`. Takes the same arguments as `test`.
""" """
# fail with nice error message if nose is not present
nose = import_nose()
# compile argv # compile argv
argv = self._test_argv(verbose, extra_argv) argv = self._test_argv(verbose, extra_argv)
# numpy way of doing coverage # numpy way of doing coverage
if coverage: if coverage:
argv += ['--cover-package=%s' % self.package_name, '--with-coverage', argv += ['--cover-package=%s' % self.package_name,
'--cover-tests', '--cover-inclusive', '--cover-erase'] '--with-coverage', '--cover-tests',
'--cover-inclusive', '--cover-erase']
# Capture output only if needed # Capture output only if needed
if not capture: if not capture:
...@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester): ...@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester):
:param extra_argv: List with any extra arguments to pass to nosetests. :param extra_argv: List with any extra arguments to pass to nosetests.
:type coverage: bool :type coverage: bool
:param coverage: If True, report coverage of Theano code. Default is False. :param coverage: If True, report coverage of Theano
code. Default is False.
:type capture: bool :type capture: bool
:param capture: If True, capture the standard output of the tests, like :param capture: If True, capture the standard output of the tests, like
...@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester): ...@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester):
def main(modulename): def main(modulename):
debug = False
if 0: if 0:
unittest.main() unittest.main()
elif len(sys.argv) == 2 and sys.argv[1] == "--debug": elif len(sys.argv) == 2 and sys.argv[1] == "--debug":
......
...@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>" ...@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>"
whitelist_flake8 = [ whitelist_flake8 = [
"__init__.py", "__init__.py",
"version.py",
"tests/test_gradient.py", "tests/test_gradient.py",
"tests/test_config.py", "tests/test_config.py",
"tests/diverse_tests.py", "tests/diverse_tests.py",
...@@ -31,37 +30,20 @@ whitelist_flake8 = [ ...@@ -31,37 +30,20 @@ whitelist_flake8 = [
"tests/test_record.py", "tests/test_record.py",
"tests/__init__.py", "tests/__init__.py",
"tests/test_updates.py", "tests/test_updates.py",
"tests/main.py",
"tests/test_pickle_unpickle_theano_fn.py", "tests/test_pickle_unpickle_theano_fn.py",
"tests/test_determinism.py", "tests/test_determinism.py",
"tests/record.py", "tests/record.py",
"tests/test_printing.py",
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/ops.py",
"compile/debugmode.py",
"compile/function.py",
"compile/pfunc.py",
"compile/mode.py",
"compile/profilemode.py",
"compile/builders.py",
"compile/__init__.py", "compile/__init__.py",
"compile/profiling.py", "compile/profiling.py",
"compile/function_module.py",
"compile/sharedvalue.py",
"compile/monitormode.py",
"compile/io.py",
"compile/module.py",
"compile/tests/test_builders.py", "compile/tests/test_builders.py",
"compile/tests/test_misc.py", "compile/tests/test_misc.py",
"compile/tests/test_monitormode.py", "compile/tests/test_monitormode.py",
"compile/tests/test_function_module.py", "compile/tests/test_function_module.py",
"compile/tests/test_inplace_opt_for_value.py",
"compile/tests/test_shared.py", "compile/tests/test_shared.py",
"compile/tests/test_ops.py", "compile/tests/test_ops.py",
"compile/tests/test_pfunc.py", "compile/tests/test_pfunc.py",
"compile/tests/test_module.py",
"compile/tests/test_debugmode.py", "compile/tests/test_debugmode.py",
"compile/tests/test_profiling.py", "compile/tests/test_profiling.py",
"typed_list/type.py", "typed_list/type.py",
...@@ -94,16 +76,13 @@ whitelist_flake8 = [ ...@@ -94,16 +76,13 @@ whitelist_flake8 = [
"tensor/io.py", "tensor/io.py",
"tensor/elemwise_cgen.py", "tensor/elemwise_cgen.py",
"tensor/raw_random.py", "tensor/raw_random.py",
"tensor/randomstreams.py",
"tensor/blas_scipy.py", "tensor/blas_scipy.py",
"tensor/basic.py", "tensor/basic.py",
"tensor/tests/test_subtensor.py", "tensor/tests/test_subtensor.py",
"tensor/tests/test_utils.py", "tensor/tests/test_utils.py",
"tensor/tests/test_nlinalg.py", "tensor/tests/test_nlinalg.py",
"tensor/tests/test_randomstreams.py",
"tensor/tests/test_shared_randomstreams.py", "tensor/tests/test_shared_randomstreams.py",
"tensor/tests/test_misc.py", "tensor/tests/test_misc.py",
"tensor/tests/test_naacl09.py",
"tensor/tests/mlp_test.py", "tensor/tests/mlp_test.py",
"tensor/tests/test_opt_uncanonicalize.py", "tensor/tests/test_opt_uncanonicalize.py",
"tensor/tests/test_opt.py", "tensor/tests/test_opt.py",
...@@ -155,7 +134,6 @@ whitelist_flake8 = [ ...@@ -155,7 +134,6 @@ whitelist_flake8 = [
"sandbox/test_theano_object.py", "sandbox/test_theano_object.py",
"sandbox/test_scan.py", "sandbox/test_scan.py",
"sandbox/rng_mrg.py", "sandbox/rng_mrg.py",
"sandbox/downsample.py",
"sandbox/solve.py", "sandbox/solve.py",
"sandbox/theano_object.py", "sandbox/theano_object.py",
"sandbox/scan.py", "sandbox/scan.py",
...@@ -190,7 +168,6 @@ whitelist_flake8 = [ ...@@ -190,7 +168,6 @@ whitelist_flake8 = [
"sandbox/cuda/nvcc_compiler.py", "sandbox/cuda/nvcc_compiler.py",
"sandbox/cuda/neighbours.py", "sandbox/cuda/neighbours.py",
"sandbox/cuda/tests/walltime.py", "sandbox/cuda/tests/walltime.py",
"sandbox/cuda/tests/test_fftconv.py",
"sandbox/cuda/tests/test_gradient.py", "sandbox/cuda/tests/test_gradient.py",
"sandbox/cuda/tests/test_neighbours.py", "sandbox/cuda/tests/test_neighbours.py",
"sandbox/cuda/tests/test_conv_cuda_ndarray.py", "sandbox/cuda/tests/test_conv_cuda_ndarray.py",
...@@ -218,7 +195,6 @@ whitelist_flake8 = [ ...@@ -218,7 +195,6 @@ whitelist_flake8 = [
"sandbox/scan_module/tests/test_utils.py", "sandbox/scan_module/tests/test_utils.py",
"sandbox/scan_module/tests/test_scan.py", "sandbox/scan_module/tests/test_scan.py",
"sandbox/linalg/ops.py", "sandbox/linalg/ops.py",
"sandbox/linalg/kron.py",
"sandbox/linalg/__init__.py", "sandbox/linalg/__init__.py",
"sandbox/linalg/tests/test_linalg.py", "sandbox/linalg/tests/test_linalg.py",
"sandbox/gpuarray/comp.py", "sandbox/gpuarray/comp.py",
...@@ -288,24 +264,12 @@ whitelist_flake8 = [ ...@@ -288,24 +264,12 @@ whitelist_flake8 = [
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/destroyhandler.py", "gof/destroyhandler.py",
"gof/vm.py",
"gof/cutils.py",
"gof/compiledir.py",
"gof/unify.py", "gof/unify.py",
"gof/lazylinker_c.py",
"gof/optdb.py",
"gof/utils.py",
"gof/graph.py", "gof/graph.py",
"gof/callcache.py",
"gof/python25.py",
"gof/type.py",
"gof/__init__.py", "gof/__init__.py",
"gof/cc.py", "gof/cc.py",
"gof/opt.py", "gof/opt.py",
"gof/compilelock.py",
"gof/link.py", "gof/link.py",
"gof/sched.py",
"gof/toolbox.py",
"gof/fg.py", "gof/fg.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py", "gof/cmodule.py",
...@@ -322,9 +286,6 @@ whitelist_flake8 = [ ...@@ -322,9 +286,6 @@ whitelist_flake8 = [
"gof/tests/test_cc.py", "gof/tests/test_cc.py",
"gof/tests/test_compute_test_value.py", "gof/tests/test_compute_test_value.py",
"gof/sandbox/equilibrium.py", "gof/sandbox/equilibrium.py",
"sandbox/cuda/opt_util.py",
"gof/tests/test_utils.py",
"tensor/tests/_test_mpi_roundtrip.py",
] ]
......
try: try:
from theano.generated_version import * from theano.generated_version import * # noqa
except ImportError: except ImportError:
short_version = 'unknown' short_version = 'unknown'
version = 'unknown' version = 'unknown'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论