Commit c4b8a8d6 authored by Pascal Lamblin

Raises an Exception when trying to compute the gradient of a non-scalar cost.

Closes #538.
Parent 060fc688
@@ -3978,9 +3978,9 @@ outer = Outer()
 def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
     """
-    :type cost: `Variable`
+    :type cost: Scalar (0-dimensional) `Variable`
     :type wrt: `Variable` or list of `Variable`s.
-    :type g_cost: `Variable` broadcastable to size of `cost`, or None
+    :type g_cost: Scalar `Variable`, or None
     :param g_cost: an expression for the gradient through cost. The default is
         ``ones_like(cost)``.
     :param consider_constant: a list of expressions not to backpropagate through
@@ -4003,9 +4003,11 @@ def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
         raise TypeError('In tensor.grad(), cost argument should be a TensorVariable.', cost)
     if cost.type.ndim:
-        _warn('the passing of a non-scalar cost to theano.tensor.grad() is deprecated.'
-              ' Use the lower-level '
-              'theano.gradient if you really want to do this')
+        raise TypeError(
+            'In tensor.grad, "cost" argument should be a scalar, but ndim'
+            ' is %i (should be 0). If you want to compute the gradient of'
+            ' the sum of cost, you should use cost.sum().'
+            % cost.type.ndim)
     if g_cost is None:
         g_cost = ones_like(cost)
...
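The behavior change can be illustrated with a small standalone sketch. Note that `FakeCost` and `check_scalar_cost` below are hypothetical stand-ins written for illustration, not Theano API: before this commit, passing a non-scalar cost only produced a deprecation warning; afterwards, `grad()` raises a `TypeError`, and callers must reduce the cost (e.g. with `cost.sum()`) before differentiating.

```python
class FakeCost:
    """Hypothetical stand-in for a Theano variable; only ndim matters here."""
    def __init__(self, ndim):
        self.ndim = ndim


def check_scalar_cost(cost):
    # Mirrors the check this commit adds: reject any cost with ndim != 0.
    if cost.ndim:
        raise TypeError(
            'In tensor.grad, "cost" argument should be a scalar, but ndim'
            ' is %i (should be 0). If you want to compute the gradient of'
            ' the sum of cost, you should use cost.sum().'
            % cost.ndim)


check_scalar_cost(FakeCost(0))      # scalar cost: passes silently

try:
    check_scalar_cost(FakeCost(1))  # vector cost: now rejected
except TypeError as e:
    print(e)
```

In real Theano code, the fix on the caller side is a one-line change: replace `tensor.grad(cost, wrt)` with `tensor.grad(cost.sum(), wrt)` when `cost` is a vector of per-example losses.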