提交 16add309 authored 作者: nouiz's avatar nouiz

Merge pull request #946 from goodfeli/type_check

add some type checking
......@@ -109,7 +109,7 @@ following methods:
elements of inputs[input_idx] have an effect on the elements of
outputs[output_idx].
The ``node'' parameter is needed to determine the number of
The ``node`` parameter is needed to determine the number of
inputs. Some ops such as Subtensor take a variable number of
inputs.
......@@ -159,7 +159,11 @@ following methods:
If the output is not differentiable with respect to an input
then this method should be defined to return a variable of type
NullType for that input.
NullType for that input. Likewise, if you have not implemented the
grad computation for some input, you may return a variable of type
NullType for that input. theano.gradient contains convenience methods
that can construct the variable for you: :func:`theano.gradient.grad_undefined` and
:func:`theano.gradient.grad_not_implemented`, respectively.
If an element of output_gradient is of type theano.gradient.DisconnectedType,
it means that the cost is not a function of this output. If any of the
......
......@@ -446,11 +446,21 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
raise TypeError('Elements of consider_constant must be '
'variables, but got ' + str(type(elem)))
if isinstance(wrt, set):
raise TypeError("wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a matching"
" order.")
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
if not using_list and not using_tuple:
wrt = [wrt]
for elem in wrt:
if not isinstance(elem, Variable):
raise TypeError("Expected Variable, got " + str(elem) +
" of type "+str(type(elem)))
var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt)
# build a dict mapping var to the gradient of cost with respect to var
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论