提交 8353b754 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added error checking of consider_constant

上级 69e3825b
...@@ -269,6 +269,21 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -269,6 +269,21 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
""" """
if consider_constant is None: if consider_constant is None:
consider_constant = [] consider_constant = []
else:
#error checking on consider_constant: verify that it is a collection
# of theano variables
# this is important, if someone accidentally passes a nested data
# structure with theano variables at the leaves, only the root will
# be properly considered constant
if not hasattr(consider_constant, '__iter__'):
raise TypeError('consider_constant must be an iterable collection,'
' got '+str(type(consider_constant)))
for elem in consider_constant:
if not isinstance(elem, gof.Variable):
raise TypeError('Elements of consider_constant must be variables,'
'but got '+str(type(elem)))
if not isinstance(cost, TensorVariable): if not isinstance(cost, TensorVariable):
raise TypeError('In tensor.grad(), cost argument should be a TensorVariable.', cost) raise TypeError('In tensor.grad(), cost argument should be a TensorVariable.', cost)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论