提交 d08563a8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

grad_depth parameter for OpFromGraph

上级 c60e26be
......@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], linker='c|py')
op = OpFromGraph([x, y, z], [e], linker='c|py', grad_depth = 2)
f = op(x, y, z)
f = f - T.grad(f, y)
fn = function([x, y, z], [f])
......
......@@ -194,6 +194,12 @@ class OpFromGraph(gof.Op):
unpack_single = False
borrow_outputs = False
OpFromGraph takes an additional input, grad_depth. If grad_depth
is n, OpFromGraph will make special Ops for gradients up to the
nth level, allowing the user to differentiate this op up to n
times. The parameter defaults to 1. If grad_depth == 0, the op
will not be differentiable.
Example:
x, y, z = tensor.scalars('xyz')
e = x + y * z
......@@ -203,8 +209,7 @@ class OpFromGraph(gof.Op):
fn = function([x, y, z], [e2])
"""
def __init__(self, inputs, outputs, **kwargs):
do_grad = kwargs.pop('do_grad', True)
def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
raise ValueError('The borrow_outputs and unpack_single options cannot be True')
kwargs['unpack_single'] = False
......@@ -214,15 +219,19 @@ class OpFromGraph(gof.Op):
self.outputs = outputs
self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs]
if do_grad:
if grad_depth > 0:
import gradient as G
output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
gs = map(gd.get, self.inputs)
self.grad_ops = [OpFromGraph(inputs + output_grads,
[g],
do_grad = False)
for g in gs]
self.grad_ops = []
for g in gs:
if g is None:
self.grad_ops.append(lambda *args: None)
else:
self.grad_ops.append(OpFromGraph(inputs + output_grads,
[g],
grad_depth = grad_depth - 1))
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论