提交 dd801b30 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

replaced zeros_like(x) with x.zeros_like() in order to support sparse

types
上级 0db5749e
......@@ -512,7 +512,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
#the gradient of the constants is 0
for const in consider_constant:
grad_dict[const] = tensor.zeros_like(const)
grad_dict[const] = const.zeros_like()
#variables that do not influence the cost have zero gradient.
#if wrt is such a variable, populate the grad_dict with this info
......@@ -552,7 +552,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
term_dict[node][i] = node.inputs[i].zeros_like()
return term_dict[node]
......@@ -586,7 +586,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
else:
#this variable is not connected to the cost in the computational
#graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
grad_dict[var] = var.zeros_like()
return grad_dict[var]
......@@ -662,7 +662,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
[access_grad_cache(var) for var in node.outputs]))
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
term_dict[node][i] = node.inputs[i].zeros_like()
return term_dict[node]
......@@ -694,7 +694,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
else:
#this variable is not connected to the cost in the computational
#graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
grad_dict[var] = var.zeros_like()
return grad_dict[var]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论