提交 dd801b30，作者：Ian Goodfellow

replaced zeros_like(x) with x.zeros_like() in order to support sparse types
上级 0db5749e
@@ -512,7 +512,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore'):
         #the gradient of the constants is 0
         for const in consider_constant:
-            grad_dict[const] = tensor.zeros_like(const)
+            grad_dict[const] = const.zeros_like()
     #variables that do not influence the cost have zero gradient.
     #if wrt is such a variable, populate the grad_dict with this info
@@ -552,7 +552,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore'):
         term_dict[node] = list(input_grads)
         for i in xrange(len(term_dict[node])):
             if term_dict[node][i] is None:
-                term_dict[node][i] = tensor.zeros_like(node.inputs[i])
+                term_dict[node][i] = node.inputs[i].zeros_like()
         return term_dict[node]
@@ -586,7 +586,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore'):
         else:
             #this variable is not connected to the cost in the computational
             #graph so the gradient on it is zero
-            grad_dict[var] = tensor.zeros_like(var)
+            grad_dict[var] = var.zeros_like()
         return grad_dict[var]
@@ -662,7 +662,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
                 [access_grad_cache(var) for var in node.outputs]))
         for i in xrange(len(term_dict[node])):
             if term_dict[node][i] is None:
-                term_dict[node][i] = tensor.zeros_like(node.inputs[i])
+                term_dict[node][i] = node.inputs[i].zeros_like()
         return term_dict[node]
@@ -694,7 +694,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
         else:
             #this variable is not connected to the cost in the computational
             #graph so the gradient on it is zero
-            grad_dict[var] = tensor.zeros_like(var)
+            grad_dict[var] = var.zeros_like()
         return grad_dict[var]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论