提交 4cf7afb4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2952 from abergeron/flake8

Flake8 work
...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op): ...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs) - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
...@@ -68,7 +70,7 @@ class OpFromGraph(gof.Op): ...@@ -68,7 +70,7 @@ class OpFromGraph(gof.Op):
for i in inputs + outputs: for i in inputs + outputs:
if not isinstance(i, gof.Variable): if not isinstance(i, gof.Variable):
raise TypeError( raise TypeError(
'inputs and outputs must be Variable instances', i) 'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs: if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates are not allowed in kwargs')
...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op): ...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient. # not see them. Otherwise their is problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
used_inputs = [var for var in gof.graph.inputs(outputs)
if not isinstance(var, gof.Constant)]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars, new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, replace=dict(zip(self.shared_inputs,
...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op): ...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" raise TypeError("Wrong type, expected %s but got %s" %
% (type, input.type)) (type, input.type))
return gof.Apply(self, return gof.Apply(self,
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
...@@ -143,9 +143,10 @@ class OpFromGraph(gof.Op): ...@@ -143,9 +143,10 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops grad_ops = self.grad_ops
else: else:
gs = theano.gradient.grad(cost=None, gs = theano.gradient.grad(cost=None,
known_grads=dict(zip(self.new_outputs, output_grads)), known_grads=dict(zip(self.new_outputs,
wrt=self.new_inputs, output_grads)),
disconnected_inputs='ignore') wrt=self.new_inputs,
disconnected_inputs='ignore')
grad_ops = [] grad_ops = []
for g in gs: for g in gs:
......
差异被折叠。
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """ """Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__ = 'restructuredtext en'
from theano import gof from theano import gof
from sharedvalue import SharedVariable from sharedvalue import SharedVariable
...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable ...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
import logging import logging
_logger = logging.getLogger("theano.compile.io") _logger = logging.getLogger("theano.compile.io")
__docformat__ = 'restructuredtext en'
class SymbolicInput(object): class SymbolicInput(object):
""" """
...@@ -17,42 +18,55 @@ class SymbolicInput(object): ...@@ -17,42 +18,55 @@ class SymbolicInput(object):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
mutable: Bool (default: False if update is None, True if update is not None) False: do not permit the compiled function to modify the
True: permit the compiled function to modify the python object being passed as the input python object being passed as the input.
False: do not permit the compiled function to modify the python object being passed as the input.
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
implicit: Bool (default: False) implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we are in the See help(In). Note that 'None' is not allowed here, since we
symbolic case. are in the symbolic case.
""" """
def __init__(self, variable, name=None, update=None, mutable=None, def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=None, autoname=True, strict=False, allow_downcast=None, autoname=True,
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
if (autoname and name is None): if (autoname and name is None):
self.name = variable.name self.name = variable.name
else: else:
self.name = name self.name = name
...@@ -146,36 +160,54 @@ class In(SymbolicInput): ...@@ -146,36 +160,54 @@ class In(SymbolicInput):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
value: Any type. value: Any type.
The initial/default value for this input. If update is None, this input acts just like The initial/default value for this input. If update is None,
an argument with a default value in Python. If update is not None, changes to this this input acts just like an argument with a default value in
value will "stick around", whether due to an update or a user's explicit action. Python. If update is not None, changes to this value will
"stick around", whether due to an update or a user's explicit
action.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None) mutable: Bool (default: False if update is None, True if update is
True: permit the compiled function to modify the python object being passed as the input not None)
False: do not permit the compiled function to modify the python object being passed as the input.
True: permit the compiled function to modify the python object
being passed as the input
False: do not permit the compiled function to modify the
python object being passed as the input.
borrow: Bool (default: take the same value as mutable) borrow: Bool (default: take the same value as mutable)
True: permit the output of the compiled function to be aliased to the input
True: permit the output of the compiled function to be aliased
to the input
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -194,11 +226,11 @@ class In(SymbolicInput): ...@@ -194,11 +226,11 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt, # Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None,
implicit=None, borrow=None, shared=False): autoname=True, implicit=None, borrow=None, shared=False):
# if shared, an input's value comes from its persistent
# if shared, an input's value comes from its persistent storage, not from a default stored # storage, not from a default stored in the function or from
# in the function or from the caller # the caller
self.shared = shared self.shared = shared
if borrow is None: if borrow is None:
...@@ -211,25 +243,25 @@ class In(SymbolicInput): ...@@ -211,25 +243,25 @@ class In(SymbolicInput):
# aliased to the input. Thus mutable=True should require borrow=True. # aliased to the input. Thus mutable=True should require borrow=True.
if mutable and not self.borrow: if mutable and not self.borrow:
raise AssertionError( raise AssertionError(
"Symbolic input for variable %s (name=%s) has " "Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is " "flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the " "incompatible since mutable=True implies that the "
"input variable may be both aliased (borrow=True) and " "input variable may be both aliased (borrow=True) and "
"overwritten.", "overwritten.",
variable, name) variable, name)
if implicit is None: if implicit is None:
implicit = (isinstance(value, gof.Container) or implicit = (isinstance(value, gof.Container) or
isinstance(value, SharedVariable)) isinstance(value, SharedVariable))
super(In, self).__init__( super(In, self).__init__(
variable=variable, variable=variable,
name=name, name=name,
update=update, update=update,
mutable=mutable, mutable=mutable,
strict=strict, strict=strict,
allow_downcast=allow_downcast, allow_downcast=allow_downcast,
autoname=autoname, autoname=autoname,
implicit=implicit) implicit=implicit)
self.value = value self.value = value
if self.implicit and value is None: if self.implicit and value is None:
raise TypeError('An implicit input must be given a default value') raise TypeError('An implicit input must be given a default value')
......
...@@ -2,35 +2,33 @@ ...@@ -2,35 +2,33 @@
""" """
from __future__ import print_function from __future__ import print_function
import logging import logging
import warnings
from textwrap import dedent
import numpy import numpy
import theano import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.compile.ops import register_view_op_c_code, _output_guard from theano.compile.ops import _output_guard
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
("When using the default mode, we will remove optimizer with these " ("When using the default mode, we will remove optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
("When using the default mode, we will add optimizer with these tags. " ("When using the default mode, we will add optimizer with "
"Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
("When using the default mode, we will require optimizer with these " ("When using the default mode, we will require optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
def check_equal(x, y): def check_equal(x, y):
...@@ -50,15 +48,15 @@ def check_equal(x, y): ...@@ -50,15 +48,15 @@ def check_equal(x, y):
y = y.todense() y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if (x.dtype != y.dtype if (x.dtype != y.dtype or
or x.shape != y.shape x.shape != y.shape or
or numpy.any(abs(x - y) > 1e-10)): numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
else: else:
if x != y: if x != y:
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for # If a string is passed as the linker argument in the constructor for
...@@ -144,11 +142,11 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -144,11 +142,11 @@ class AddDestroyHandler(gof.Optimizer):
for o in fgraph.outputs: for o in fgraph.outputs:
try: try:
fgraph.replace_validate(o, _output_guard(o), fgraph.replace_validate(o, _output_guard(o),
reason='output_guard') reason='output_guard')
_logger.info("Output variable %s required output_guard, " _logger.info("Output variable %s required output_guard, "
"how was this output left unprotected against " "how was this output left unprotected against "
"destructive operations?" "destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError:
# This output is already impossible to destroy. # This output is already impossible to destroy.
# No guard necessary # No guard necessary
...@@ -188,50 +186,50 @@ class PrintCurrentFunctionGraph(gof.Optimizer): ...@@ -188,50 +186,50 @@ class PrintCurrentFunctionGraph(gof.Optimizer):
optdb = gof.SequenceDB() optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge') 0, 'fast_run', 'fast_compile', 'merge')
# rearranges elemwise expressions # rearranges elemwise expressions
optdb.register('canonicalize', gof.EquilibriumDB(), optdb.register('canonicalize', gof.EquilibriumDB(),
1, 'fast_run', 'fast_compile') 1, 'fast_run', 'fast_compile')
optdb.register('merge1.2', gof.MergeOptimizer(), optdb.register('merge1.2', gof.MergeOptimizer(),
1.2, 'fast_run', 'fast_compile', 'merge') 1.2, 'fast_run', 'fast_compile', 'merge')
optdb.register('Print1.21', PrintCurrentFunctionGraph('Post-canonicalize'), optdb.register('Print1.21', PrintCurrentFunctionGraph('Post-canonicalize'),
1.21,) # 'fast_run', 'fast_compile') 1.21,) # 'fast_run', 'fast_compile')
# replace unstable subgraphs # replace unstable subgraphs
optdb.register('stabilize', gof.EquilibriumDB(), optdb.register('stabilize', gof.EquilibriumDB(),
1.5, 'fast_run') 1.5, 'fast_run')
optdb.register('Print1.51', PrintCurrentFunctionGraph('Post-stabilize'), optdb.register('Print1.51', PrintCurrentFunctionGraph('Post-stabilize'),
1.51,) # 'fast_run', 'fast_compile') 1.51,) # 'fast_run', 'fast_compile')
# misc special cases for speed # misc special cases for speed
optdb.register('specialize', gof.EquilibriumDB(), optdb.register('specialize', gof.EquilibriumDB(),
2, 'fast_run', 'fast_compile_gpu') 2, 'fast_run', 'fast_compile_gpu')
# misc special cases for speed that break canonicalization # misc special cases for speed that break canonicalization
optdb.register('uncanonicalize', gof.EquilibriumDB(), optdb.register('uncanonicalize', gof.EquilibriumDB(),
3, 'fast_run') 3, 'fast_run')
# misc special cases for speed that are dependent on the device. # misc special cases for speed that are dependent on the device.
optdb.register('specialize_device', gof.EquilibriumDB(), optdb.register('specialize_device', gof.EquilibriumDB(),
48.6, 'fast_run') # must be after gpu stuff at 48.5 48.6, 'fast_run') # must be after gpu stuff at 48.5
# especially constant merge # especially constant merge
optdb.register('merge2', gof.MergeOptimizer(), optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge') 49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(), optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4) 49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
# final pass just to make sure # final pass just to make sure
optdb.register('merge3', gof.MergeOptimizer(), optdb.register('merge3', gof.MergeOptimizer(),
100, 'fast_run', 'merge') 100, 'fast_run', 'merge')
class Mode(object): class Mode(object):
...@@ -287,7 +285,8 @@ class Mode(object): ...@@ -287,7 +285,8 @@ class Mode(object):
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__, return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
self.provided_linker, self.provided_optimizer) self.provided_linker,
self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -306,19 +305,19 @@ class Mode(object): ...@@ -306,19 +305,19 @@ class Mode(object):
def including(self, *tags): def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
# N.B. opt might be a Query instance, not sure what else it might be... # N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags)) return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags): def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.excluding(*tags)) return self.__class__(linker=link, optimizer=opt.excluding(*tags))
def requiring(self, *tags): def requiring(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
return self.__class__(linker=link, optimizer=opt.requiring(*tags)) return self.__class__(linker=link, optimizer=opt.requiring(*tags))
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
...@@ -364,10 +363,11 @@ def get_mode(orig_string): ...@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker. # DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # This might be required if the string is 'ProfileMode'
from profilemode import ProfileMode, prof_mode_instance_to_print from profilemode import ProfileMode # noqa
ret = eval(string from profilemode import prof_mode_instance_to_print
+ '(linker=config.linker, optimizer=config.optimizer)') ret = eval(string +
'(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes: elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
......
from __future__ import print_function from __future__ import print_function
# Note: this code was initially copied from the 'pyutools' package by its # Note: this code was initially copied from the 'pyutools' package by its
# original author, and re-licensed under Theano's license. # original author, and re-licensed under Theano's license.
import numpy
import theano import theano
from theano.compile.mode import Mode from theano.compile.mode import Mode
...@@ -48,7 +48,7 @@ class MonitorMode(Mode): ...@@ -48,7 +48,7 @@ class MonitorMode(Mode):
if optimizer == 'default': if optimizer == 'default':
optimizer = theano.config.optimizer optimizer = theano.config.optimizer
if (linker is not None and if (linker is not None and
not isinstance(linker.mode, MonitorMode)): not isinstance(linker.mode, MonitorMode)):
raise Exception("MonitorMode can only use its own linker! You " raise Exception("MonitorMode can only use its own linker! You "
"should not provide one.", linker) "should not provide one.", linker)
...@@ -86,7 +86,7 @@ class MonitorMode(Mode): ...@@ -86,7 +86,7 @@ class MonitorMode(Mode):
def detect_nan(i, node, fn): def detect_nan(i, node, fn):
for output in fn.outputs: for output in fn.outputs:
if (not isinstance(numpy.random.RandomState, output[0]) and if (not isinstance(numpy.random.RandomState, output[0]) and
numpy.isnan(output[0]).any()): numpy.isnan(output[0]).any()):
print('*** NaN detected ***') print('*** NaN detected ***')
theano.printing.debugprint(node) theano.printing.debugprint(node)
print('Inputs : %s' % [input[0] for input in fn.inputs]) print('Inputs : %s' % [input[0] for input in fn.inputs])
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论