提交 d0c3feb7 authored 作者: David Warde-Farley's avatar David Warde-Farley 提交者: Arnaud Bergeron

Use Callable abc where appropriate.

上级 a5f511a2
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from theano.compat.six import PY3, b, BytesIO, next from theano.compat.six import PY3, b, BytesIO, next
from theano.compat.six.moves import configparser from theano.compat.six.moves import configparser
from theano.compat.six.moves import reload_module as reload from theano.compat.six.moves import reload_module as reload
import collections
__all__ = ['PY3', 'b', 'BytesIO', 'next', 'configparser', 'reload'] __all__ = ['PY3', 'b', 'BytesIO', 'next', 'configparser', 'reload']
...@@ -77,7 +78,7 @@ __all__ += ['cmp', 'operator_div', 'partial', 'defaultdict', 'deque', ...@@ -77,7 +78,7 @@ __all__ += ['cmp', 'operator_div', 'partial', 'defaultdict', 'deque',
class DefaultOrderedDict(OrderedDict): class DefaultOrderedDict(OrderedDict):
def __init__(self, default_factory=None, *a, **kw): def __init__(self, default_factory=None, *a, **kw):
if (default_factory is not None and if (default_factory is not None and
not callable(default_factory)): not isinstance(default_factory, collections.Callable)):
raise TypeError('first argument must be callable') raise TypeError('first argument must be callable')
OrderedDict.__init__(self, *a, **kw) OrderedDict.__init__(self, *a, **kw)
self.default_factory = default_factory self.default_factory = default_factory
......
...@@ -13,6 +13,7 @@ from theano.compat.six.moves import xrange ...@@ -13,6 +13,7 @@ from theano.compat.six.moves import xrange
import numpy import numpy
import collections
def register_view_op_c_code(type, code, version=()): def register_view_op_c_code(type, code, version=()):
...@@ -569,7 +570,7 @@ def as_op(itypes, otypes, infer_shape=None): ...@@ -569,7 +570,7 @@ def as_op(itypes, otypes, infer_shape=None):
itypes = list(itypes) itypes = list(itypes)
otypes = list(otypes) otypes = list(otypes)
if infer_shape is not None and not callable(infer_shape): if infer_shape is not None and not isinstance(infer_shape, collections.Callable):
raise TypeError("infer_shape needs to be a callable") raise TypeError("infer_shape needs to be a callable")
def make_op(fn): def make_op(fn):
......
...@@ -14,6 +14,7 @@ from theano.compat.six import StringIO ...@@ -14,6 +14,7 @@ from theano.compat.six import StringIO
import theano import theano
from theano.compat import configparser as ConfigParser from theano.compat import configparser as ConfigParser
import collections
_logger = logging.getLogger('theano.configparser') _logger = logging.getLogger('theano.configparser')
...@@ -265,7 +266,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -265,7 +266,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
configparam.in_c_key = in_c_key configparam.in_c_key = in_c_key
# Trigger a read of the value from config files and env vars # Trigger a read of the value from config files and env vars
# This allow to filter wrong value from the user. # This allow to filter wrong value from the user.
if not callable(configparam.default): if not isinstance(configparam.default, collections.Callable):
configparam.__get__() configparam.__get__()
else: else:
# We do not want to evaluate now the default value # We do not want to evaluate now the default value
...@@ -309,7 +310,7 @@ class ConfigParam(object): ...@@ -309,7 +310,7 @@ class ConfigParam(object):
for v in self.default(): for v in self.default():
val_str = v val_str = v
self.__set__(None, val_str) self.__set__(None, val_str)
elif callable(self.default): elif isinstance(self.default, collections.Callable):
val_str = self.default() val_str = self.default()
else: else:
val_str = self.default val_str = self.default
...@@ -365,7 +366,7 @@ class TypedParam(ConfigParam): ...@@ -365,7 +366,7 @@ class TypedParam(ConfigParam):
def filter(val): def filter(val):
cast_val = mytype(val) cast_val = mytype(val)
if callable(is_valid): if isinstance(is_valid, collections.Callable):
if is_valid(cast_val): if is_valid(cast_val):
return cast_val return cast_val
else: else:
......
...@@ -12,6 +12,7 @@ import warnings ...@@ -12,6 +12,7 @@ import warnings
import hashlib import hashlib
import numpy as np import numpy as np
import collections
try: try:
import pydot as pd import pydot as pd
...@@ -210,7 +211,7 @@ N.B.: ...@@ -210,7 +211,7 @@ N.B.:
def _print_fn(op, xin): def _print_fn(op, xin):
for attr in op.attrs: for attr in op.attrs:
temp = getattr(xin, attr) temp = getattr(xin, attr)
if callable(temp): if isinstance(temp, collections.Callable):
pmsg = temp() pmsg = temp()
else: else:
pmsg = temp pmsg = temp
......
...@@ -2,6 +2,7 @@ from __future__ import print_function ...@@ -2,6 +2,7 @@ from __future__ import print_function
import copy, inspect import copy, inspect
import theano import theano
import theano.tensor as T import theano.tensor as T
import collections
#import klass #import klass
...@@ -36,7 +37,7 @@ class InitGraph(type): ...@@ -36,7 +37,7 @@ class InitGraph(type):
# print ' adding class attribute', key # print ' adding class attribute', key
if isinstance(val, theano.Variable) and val.name is None: if isinstance(val, theano.Variable) and val.name is None:
val.name = key val.name = key
if callable(val): if isinstance(val, collections.Callable):
setattr(cls, key, staticmethod(val)) setattr(cls, key, staticmethod(val))
else: else:
setattr(cls, key, val) setattr(cls, key, val)
...@@ -317,7 +318,7 @@ if 0: ...@@ -317,7 +318,7 @@ if 0:
except Exception: except Exception:
kres = klass.KlassVariable(val) kres = klass.KlassVariable(val)
setattr(SymMod, key, kres) setattr(SymMod, key, kres)
elif callable(val) and getattr(val, '__is_symbolic'): elif isinstance(val, collections.Callable) and getattr(val, '__is_symbolic'):
setattr(SymMod, key, val) setattr(SymMod, key, val)
return SymMod() return SymMod()
......
...@@ -33,6 +33,7 @@ from theano.gradient import DisconnectedType ...@@ -33,6 +33,7 @@ from theano.gradient import DisconnectedType
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.printing import pprint from theano.printing import pprint
import collections
builtin_complex = complex builtin_complex = complex
builtin_int = int builtin_int = int
...@@ -837,7 +838,7 @@ class ScalarOp(Op): ...@@ -837,7 +838,7 @@ class ScalarOp(Op):
def __init__(self, output_types_preference=None, name=None): def __init__(self, output_types_preference=None, name=None):
self.name = name self.name = name
if output_types_preference is not None: if output_types_preference is not None:
if not callable(output_types_preference): if not isinstance(output_types_preference, collections.Callable):
raise TypeError( raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
(self.__class__, output_types_preference)) (self.__class__, output_types_preference))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论