Commit 71d77144 authored by Ian Goodfellow

Fixed a bug where taking the gradient of an expression with inplace ops could result in an expression with cyclical dependencies.
Parent b862796c
...
@@ -543,6 +543,20 @@ def _populate_grad_dict(var_to_node_to_idx,\
         if node not in term_dict:
             inputs = node.inputs
+            def try_to_copy(var):
+                if hasattr(var, 'copy'):
+                    return var.copy()
+                return var
+            # inplace ops often have other inplace ops in the expression for
+            # their gradient; this can result in cyclical dependencies, i.e.
+            # there being no order in which we can run all the resulting
+            # inplace ops without destroying some op's input before the time
+            # that it is needed. To get around this, we try to symbolically
+            # copy all of the inputs, so it is only the copy that is destroyed.
+            inputs = [try_to_copy(ipt) for ipt in inputs]
             output_grads = [access_grad_cache(var) for var in node.outputs]
             input_grads = node.op.grad(inputs, output_grads)
...
@@ -1057,6 +1071,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
     symbolic_grad = grad(cost, tensor_pt, g_cost,
                          disconnected_inputs='ignore')
     grad_fn = function(tensor_pt, symbolic_grad)
     for test_num in xrange(n_tests):
...
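
For readers unfamiliar with the failure mode, the sketch below (plain NumPy, not Theano code, and not part of this commit) illustrates why copying an input protects it from an inplace consumer: the inplace op destroys only the copy, so the original value stays available to any other op that still needs it. The inplace_negate helper is hypothetical and stands in for an arbitrary destructive op; in the commit itself, try_to_copy falls back to the original variable when no symbolic copy method exists, so the protection is best-effort.

import numpy as np

def inplace_negate(x):
    # hypothetical stand-in for an inplace op: it overwrites its input buffer
    x *= -1
    return x

a = np.array([1.0, 2.0, 3.0])
inplace_negate(a)
print(a)                   # [-1. -2. -3.]  the original buffer was destroyed

b = np.array([1.0, 2.0, 3.0])
inplace_negate(b.copy())   # same idea as try_to_copy above: only the copy is destroyed
print(b)                   # [1. 2. 3.]  the original is still usable by other ops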