提交 69f07b84 作者: James Bergstra

added optional warn_type parameter to tensor.grad

上级 eb6c4c4b
...@@ -108,7 +108,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -108,7 +108,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
if g_r and (getattr(r,'type',0) != getattr(g_r,'type', 1)): if g_r and (getattr(r,'type',0) != getattr(g_r,'type', 1)):
r_type = getattr(r,'type', None) r_type = getattr(r,'type', None)
g_r_type = getattr(g_r,'type', None) g_r_type = getattr(g_r,'type', None)
info('%s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r_type, g_r_type)) warning('%s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r_type, g_r_type))
if g_r and len(sources) == 1 and sources[0][0].name and r.name: if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
......
...@@ -2570,7 +2570,7 @@ outer = Outer() ...@@ -2570,7 +2570,7 @@ outer = Outer()
# Gradient # Gradient
######################### #########################
def grad(cost, wrt, g_cost=None, consider_constant=[]): def grad(cost, wrt, g_cost=None, consider_constant=[], warn_type=False):
""" """
:type cost: `Variable` :type cost: `Variable`
:type wrt: `Variable` or list of `Variable`s. :type wrt: `Variable` or list of `Variable`s.
...@@ -2579,6 +2579,9 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]): ...@@ -2579,6 +2579,9 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]):
``ones_like(cost)``. ``ones_like(cost)``.
:param consider_constant: a list of expressions not to backpropagate through :param consider_constant: a list of expressions not to backpropagate through
:param warn_type: a value of True will cause warnings to be logged for any Op that emits a
gradient that does not match its input type.
:rtype: `Variable` or list of `Variable`s (depending upon `wrt`) :rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
...@@ -2597,7 +2600,8 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]): ...@@ -2597,7 +2600,8 @@ def grad(cost, wrt, g_cost=None, consider_constant=[]):
if g_cost is None: if g_cost is None:
g_cost = ones_like(cost) g_cost = ones_like(cost)
inputs = gof.graph.inputs([cost]) inputs = gof.graph.inputs([cost])
gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant) gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs + consider_constant,
warn_type=warn_type)
def zero(p): def zero(p):
return TensorConstant( return TensorConstant(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论