提交 234f8629 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -283,6 +283,11 @@ class Function(object): ...@@ -283,6 +283,11 @@ class Function(object):
#def assign(c, v): #def assign(c, v):
#c.data = v #c.data = v
# Store the list of names of named inputs.
named_inputs = []
# Count the number of un-named inputs.
n_unnamed_inputs = 0
#setters = [] #setters = []
# Initialize the storage # Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace. # this loop works by modifying the elements (as variable c) of self.input_storage inplace.
...@@ -312,6 +317,10 @@ class Function(object): ...@@ -312,6 +317,10 @@ class Function(object):
finder[input.name] = c finder[input.name] = c
else: else:
finder[input.name] = DUPLICATE finder[input.name] = DUPLICATE
if input.name is None:
n_unnamed_inputs += 1
else:
named_inputs.append(input.name)
#backport #backport
#finder[input.name] = c if input.name not in finder else DUPLICATE #finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message) # inv_finder maps the container to the input (useful for one error message)
...@@ -378,7 +387,9 @@ class Function(object): ...@@ -378,7 +387,9 @@ class Function(object):
try: try:
s = finder[item] s = finder[item]
except KeyError: except KeyError:
raise TypeError("Unknown input or state: %s" % item) # Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (item, msg))
if s is DUPLICATE: if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % item) raise TypeError("Ambiguous name: %s - please check the names of the inputs of your function for duplicates." % item)
if isinstance(s, gof.Container): if isinstance(s, gof.Container):
...@@ -1014,3 +1025,43 @@ def convert_function_input(input): ...@@ -1014,3 +1025,43 @@ def convert_function_input(input):
else: else:
raise TypeError("Unknown input type: %s, expected Variable instance" % type(input), input) raise TypeError("Unknown input type: %s, expected Variable instance" % type(input), input)
def get_info_on_inputs(named_inputs, n_unnamed_inputs):
    """Build a sentence describing a function's named and un-named inputs.

    The returned text is intended to be appended to an error message, so
    that the user can see which inputs may be given by keyword.

    :param named_inputs: sequence of input names (strings). It is measured
        with `len` and its elements are joined with ', '.
    :param n_unnamed_inputs: number of inputs that do not have a name.

    :return: a single human-readable sentence, as a string.
    """
    def plural_suffix(count):
        # An 's' is appended only when there is strictly more than one item.
        if count > 1:
            return 's'
        return ''

    named_count = len(named_inputs)

    if named_count == 0:
        # No input can be given by keyword: the wording depends only on how
        # many (un-named) inputs there are.
        if n_unnamed_inputs == 0:
            return 'The function is supposed to have no input.'
        if n_unnamed_inputs == 1:
            return ("The function has a single input variable which has no "
                    "name, and thus cannot be assigned through a keyword "
                    "argument (use 'name=...' in a Variable's constructor "
                    "to give it a name).")
        # Any other count uses the plural wording.
        return ("The function has %s inputs, but none of them is named, and "
                "thus they cannot be assigned through keyword arguments "
                "(use 'name=...' in a Variable's constructor to give it a "
                "name)." % n_unnamed_inputs)

    # At least one named input: start by listing the names.
    named_part = "The function has %s named input%s (%s)" % (
        named_count, plural_suffix(named_count), ', '.join(named_inputs))

    if n_unnamed_inputs == 0:
        return named_part + "."

    # Formatted separately from `named_part` so that a '%' inside an input
    # name is never re-interpreted as a format directive.
    unnamed_suffix = plural_suffix(n_unnamed_inputs)
    unnamed_part = (", and %s unnamed input%s which thus cannot be accessed "
                    "through keyword argument%s (use 'name=...' in a "
                    "variable's constructor to give it a name)." % (
                        n_unnamed_inputs, unnamed_suffix, unnamed_suffix))
    return named_part + unnamed_part
...@@ -3452,8 +3452,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False): ...@@ -3452,8 +3452,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
If `wrt` is a list, then return a list containing the gradient of `cost` wrt If `wrt` is a list, then return a list containing the gradient of `cost` wrt
each element of the list. If an element of `wrt` is not differentiable each element of the list. If an element of `wrt` is not differentiable
with respect to the output, then a `TensorConstant` with an appropriate with respect to the output, then a zero variable is returned.
kind of zero is returned.
This function is a wrapper around the more general function This function is a wrapper around the more general function
`theano.gradient.grad_sources_inputs`.
...@@ -3473,21 +3472,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False): ...@@ -3473,21 +3472,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant, gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant,
warn_type=warn_type) warn_type=warn_type)
def zero(p): # Note that it is important to use `zeros_like` when there is no gradient,
return TensorConstant( # instead of returning a scalar constant equal to zero. Otherwise we lose
TensorType(dtype = p.type.dtype, broadcastable = []), # the guarantee that the gradient has same shape as `wrt`.
theano._asarray(0, dtype=p.type.dtype))
#try:
#it = iter(wrt)
#except:
#it = None
#if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
if isinstance(wrt, (list, tuple)): if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt] return [gmap.get(p, zeros_like(p)) for p in wrt]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zeros_like(wrt))
class numeric_grad: class numeric_grad:
"""WRITEME""" """WRITEME"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论