提交 234f8629 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

......@@ -283,6 +283,11 @@ class Function(object):
#def assign(c, v):
#c.data = v
# Store the list of names of named inputs.
named_inputs = []
# Count the number of un-named inputs.
n_unnamed_inputs = 0
#setters = []
# Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace.
......@@ -312,6 +317,10 @@ class Function(object):
finder[input.name] = c
else:
finder[input.name] = DUPLICATE
if input.name is None:
n_unnamed_inputs += 1
else:
named_inputs.append(input.name)
#backport
#finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message)
......@@ -378,7 +387,9 @@ class Function(object):
try:
s = finder[item]
except KeyError:
raise TypeError("Unknown input or state: %s" % item)
# Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (item, msg))
if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % item)
if isinstance(s, gof.Container):
......@@ -1014,3 +1025,43 @@ def convert_function_input(input):
else:
raise TypeError("Unknown input type: %s, expected Variable instance" % type(input), input)
def get_info_on_inputs(named_inputs, n_unnamed_inputs):
"""Return a human-readable description of named and un-named inputs."""
n_named_inputs = len(named_inputs)
def get_plural(n):
if n > 1:
return 's'
else:
return ''
if n_named_inputs == 0:
if n_unnamed_inputs == 0:
msg = 'The function is supposed to have no input.'
else:
if n_unnamed_inputs == 1:
msg = ("The function has a single input variable which has no "
"name, and thus cannot be assigned through a keyword"
" argument (use 'name=...' in a Variable's "
"constructor to give it a name).")
else:
# Use plural.
msg = ("The function has %s inputs, but none of them is named,"
" and thus they cannot be assigned through keyword "
"arguments (use 'name=...' in a Variable's "
"constructor to give it a name)." % n_unnamed_inputs)
else:
if n_unnamed_inputs == 0:
msg = ("The function has %s named input%s (%s)." % (
n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs)))
else:
msg = ("The function has %s named input%s (%s), and %s unnamed "
"input%s which thus cannot be accessed through keyword "
"argument%s (use 'name=...' in a variable's constructor "
"to give it a name)." % (
n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs), n_unnamed_inputs,
get_plural(n_unnamed_inputs),
get_plural(n_unnamed_inputs)))
return msg
......@@ -3452,8 +3452,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
:return: symbolic expression of gradient of `cost` with respect to `wrt`.
If `wrt` is a list, then return a list containing the gradient of `cost` wrt
each element of the list. If an element of `wrt` is not differentiable
with respect to the output, then a `TensorConstant` with an appropriate
kind of zero is returned.
with respect to the output, then a zero variable is returned.
This function is a wrapper around a the more general function
`theano.gradient.grad_sources_inputs``.
......@@ -3473,21 +3472,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant,
warn_type=warn_type)
def zero(p):
return TensorConstant(
TensorType(dtype = p.type.dtype, broadcastable = []),
theano._asarray(0, dtype=p.type.dtype))
#try:
#it = iter(wrt)
#except:
#it = None
#if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
# Note that it is important to use `zeros_like` when there is no gradient,
# instead of returning a scalar constant equal to zero. Otherwise we lose
# the guarantee that the gradient has same shape as `wrt`.
if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
return [gmap.get(p, zeros_like(p)) for p in wrt]
else:
return gmap.get(wrt, zero(wrt))
return gmap.get(wrt, zeros_like(wrt))
class numeric_grad:
"""WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论