提交 475e8ad9 authored 作者: Frederic's avatar Frederic

pep8

上级 c3d4ad81
...@@ -78,8 +78,7 @@ def grad_not_implemented(op, x_pos, x, comment=""): ...@@ -78,8 +78,7 @@ def grad_not_implemented(op, x_pos, x, comment=""):
gradient is not implemented. gradient is not implemented.
""" """
return (NullType( return (NullType((
(
"This variable is Null because the grad method for " "This variable is Null because the grad method for "
"input %s (%s) of the %s op is not implemented. %s" "input %s (%s) of the %s op is not implemented. %s"
) % (x_pos, x, op, comment)))() ) % (x_pos, x, op, comment)))()
...@@ -406,17 +405,16 @@ def grad(cost, wrt, consider_constant=None, ...@@ -406,17 +405,16 @@ def grad(cost, wrt, consider_constant=None,
if cost is not None and isinstance(cost.type, NullType): if cost is not None and isinstance(cost.type, NullType):
raise ValueError("Can't differentiate a NaN cost." raise ValueError("Can't differentiate a NaN cost."
"cost is NaN because " + \ "cost is NaN because " +
cost.type.why_null) cost.type.why_null)
if cost is not None and cost.ndim != 0: if cost is not None and cost.ndim != 0:
raise TypeError("cost must be a scalar.") raise TypeError("cost must be a scalar.")
if isinstance(wrt, set): if isinstance(wrt, set):
raise TypeError("wrt must not be a set. sets have no defined " raise TypeError("wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a matching" "iteration order, so we can't return gradients in a"
" order.") " matching order.")
using_list = isinstance(wrt, list) using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple) using_tuple = isinstance(wrt, tuple)
...@@ -426,7 +424,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -426,7 +424,7 @@ def grad(cost, wrt, consider_constant=None,
for elem in wrt: for elem in wrt:
if not isinstance(elem, Variable): if not isinstance(elem, Variable):
raise TypeError("Expected Variable, got " + str(elem) + raise TypeError("Expected Variable, got " + str(elem) +
" of type "+str(type(elem))) " of type " + str(type(elem)))
outputs = [] outputs = []
if cost is not None: if cost is not None:
...@@ -452,7 +450,8 @@ def grad(cost, wrt, consider_constant=None, ...@@ -452,7 +450,8 @@ def grad(cost, wrt, consider_constant=None,
# g_cost may be Disconnected or NullType. A creative use of the function, # g_cost may be Disconnected or NullType. A creative use of the function,
# sure, but nonetheless one we can and should support. So before we try # sure, but nonetheless one we can and should support. So before we try
# to cast it make sure it even has a dtype # to cast it make sure it even has a dtype
if hasattr(g_cost.type, 'dtype') and cost.type.dtype not in tensor.discrete_dtypes: if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes):
# Here we enforce the constraint that floating point variables have # Here we enforce the constraint that floating point variables have
# the same dtype as their gradient. # the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype) g_cost = g_cost.astype(cost.type.dtype)
...@@ -471,8 +470,8 @@ def grad(cost, wrt, consider_constant=None, ...@@ -471,8 +470,8 @@ def grad(cost, wrt, consider_constant=None,
'Ambiguous whether %s should be made into tensor' 'Ambiguous whether %s should be made into tensor'
' or sparse theano variable' % str(type(g_var))) ' or sparse theano variable' % str(type(g_var)))
if not isinstance(g_var.type, (NullType, DisconnectedType)) and 'float' \ if (not isinstance(g_var.type, (NullType, DisconnectedType)) and
not in str(g_var.type.dtype): 'float' not in str(g_var.type.dtype)):
raise TypeError("Gradients must always be NullType, " raise TypeError("Gradients must always be NullType, "
"DisconnectedType, or continuous, but grad was " "DisconnectedType, or continuous, but grad was "
"given a known_grad of type "+str(g_var.type)) "given a known_grad of type "+str(g_var.type))
...@@ -728,11 +727,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -728,11 +727,13 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
return var_to_app_to_idx return var_to_app_to_idx
class NullTypeGradError(TypeError): class NullTypeGradError(TypeError):
""" """
Raised when grad encounters a NullType. Raised when grad encounters a NullType.
""" """
class DisconnectedInputError(ValueError): class DisconnectedInputError(ValueError):
""" """
Raised when grad is asked to compute the gradient Raised when grad is asked to compute the gradient
...@@ -740,6 +741,7 @@ class DisconnectedInputError(ValueError): ...@@ -740,6 +741,7 @@ class DisconnectedInputError(ValueError):
disconnected_inputs='raise'. disconnected_inputs='raise'.
""" """
def _populate_grad_dict(var_to_app_to_idx, def _populate_grad_dict(var_to_app_to_idx,
grad_dict, wrt, cost_name=None): grad_dict, wrt, cost_name=None):
""" """
...@@ -902,7 +904,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -902,7 +904,7 @@ def _populate_grad_dict(var_to_app_to_idx,
"expected iterable." % str(node.op)) "expected iterable." % str(node.op))
if len(input_grads) != len(inputs): if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of" +\ raise ValueError(("%s returned the wrong number of" +
" gradient terms.") % str(node.op)) " gradient terms.") % str(node.op))
# must convert to list in case the op returns a tuple # must convert to list in case the op returns a tuple
...@@ -926,7 +928,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -926,7 +928,7 @@ def _populate_grad_dict(var_to_app_to_idx,
# used to mean undefined, zero, or disconnected. # used to mean undefined, zero, or disconnected.
# We therefore don't allow it because its usage has become # We therefore don't allow it because its usage has become
# so muddied. # so muddied.
raise TypeError(('%s.grad returned None for' +\ raise TypeError(('%s.grad returned None for' +
' a gradient term, ' ' a gradient term, '
'this is prohibited. Instead of None,' 'this is prohibited. Instead of None,'
'return zeros_like(input), DisconnectedType()(),' 'return zeros_like(input), DisconnectedType()(),'
...@@ -980,8 +982,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -980,8 +982,8 @@ def _populate_grad_dict(var_to_app_to_idx,
#Check that op.connection_pattern matches the connectivity #Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method #logic driving the op.grad method
for i, packed in \ for i, packed in enumerate(zip(inputs, input_grads,
enumerate(zip(inputs, input_grads, inputs_connected)): inputs_connected)):
ipt, ig, connected = packed ipt, ig, connected = packed
actually_connected = \ actually_connected = \
not isinstance(ig.type, DisconnectedType) not isinstance(ig.type, DisconnectedType)
...@@ -1031,7 +1033,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1031,7 +1033,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if isinstance(term.type, NullType): if isinstance(term.type, NullType):
raise NullTypeGradError("tensor.grad " raise NullTypeGradError("tensor.grad "
"encountered a NaN. " +\ "encountered a NaN. " +
term.type.why_null) term.type.why_null)
#Don't try to sum up DisconnectedType placeholders #Don't try to sum up DisconnectedType placeholders
...@@ -1243,14 +1245,12 @@ class numeric_grad(object): ...@@ -1243,14 +1245,12 @@ class numeric_grad(object):
""" """
if len(g_pt) != len(self.gf): if len(g_pt) != len(self.gf):
raise ValueError( raise ValueError('argument has wrong number of elements',
'argument has wrong number of elements',
len(g_pt)) len(g_pt))
errs = [] errs = []
for i, (a, b) in enumerate(zip(g_pt, self.gf)): for i, (a, b) in enumerate(zip(g_pt, self.gf)):
if a.shape != b.shape: if a.shape != b.shape:
raise ValueError( raise ValueError('argument element %i has wrong shape %s' % (
'argument element %i has wrong shape %s' % (
i, str((a.shape, b.shape)))) i, str((a.shape, b.shape))))
errs.append(numeric_grad.abs_rel_err(a, b)) errs.append(numeric_grad.abs_rel_err(a, b))
return errs return errs
...@@ -1368,7 +1368,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1368,7 +1368,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
def function(inputs, output): def function(inputs, output):
if mode is None: if mode is None:
f = compile.function(inputs, output, accept_inplace=True, f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True, on_unused_input='ignore') allow_input_downcast=True,
on_unused_input='ignore')
else: else:
f = compile.function(inputs, output, accept_inplace=True, f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True, mode=mode, allow_input_downcast=True, mode=mode,
......
...@@ -561,6 +561,7 @@ class test_Eigh(test_Eig): ...@@ -561,6 +561,7 @@ class test_Eigh(test_Eig):
class test_Eigh_float32(test_Eigh): class test_Eigh_float32(test_Eigh):
dtype = 'float32' dtype = 'float32'
def test_matrix_inverse_solve(): def test_matrix_inverse_solve():
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.") raise SkipTest("Scipy needed for the Solve op.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论