Commit d08563a8 authored by Olivier Breuleux

grad_depth parameter for OpFromGraph

Parent c60e26be
...@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -160,7 +160,7 @@ class T_OpFromGraph(unittest.TestCase):
def test_grad(self): def test_grad(self):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
e = x + y * z e = x + y * z
op = OpFromGraph([x, y, z], [e], linker='c|py') op = OpFromGraph([x, y, z], [e], linker='c|py', grad_depth = 2)
f = op(x, y, z) f = op(x, y, z)
f = f - T.grad(f, y) f = f - T.grad(f, y)
fn = function([x, y, z], [f]) fn = function([x, y, z], [f])
......
...@@ -194,6 +194,12 @@ class OpFromGraph(gof.Op): ...@@ -194,6 +194,12 @@ class OpFromGraph(gof.Op):
unpack_single = False unpack_single = False
borrow_outputs = False borrow_outputs = False
OpFromGraph takes an additional input, grad_depth. If grad_depth
is n, OpFromGraph will make special Ops for gradients up to the
nth level, allowing the user to differentiate this op up to n
times. The parameter defaults to 1. If grad_depth == 0, the op
will not be differentiable.
Example: Example:
x, y, z = tensor.scalars('xyz') x, y, z = tensor.scalars('xyz')
e = x + y * z e = x + y * z
...@@ -203,8 +209,7 @@ class OpFromGraph(gof.Op): ...@@ -203,8 +209,7 @@ class OpFromGraph(gof.Op):
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
""" """
def __init__(self, inputs, outputs, **kwargs): def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
do_grad = kwargs.pop('do_grad', True)
if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'): if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
raise ValueError('The borrow_outputs and unpack_single options cannot be True') raise ValueError('The borrow_outputs and unpack_single options cannot be True')
kwargs['unpack_single'] = False kwargs['unpack_single'] = False
...@@ -214,15 +219,19 @@ class OpFromGraph(gof.Op): ...@@ -214,15 +219,19 @@ class OpFromGraph(gof.Op):
self.outputs = outputs self.outputs = outputs
self.input_types = [input.type for input in inputs] self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs] self.output_types = [output.type for output in outputs]
if do_grad: if grad_depth > 0:
import gradient as G import gradient as G
output_grads = [t() for t in self.output_types] output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs) gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
gs = map(gd.get, self.inputs) gs = map(gd.get, self.inputs)
self.grad_ops = [OpFromGraph(inputs + output_grads, self.grad_ops = []
for g in gs:
if g is None:
self.grad_ops.append(lambda *args: None)
else:
self.grad_ops.append(OpFromGraph(inputs + output_grads,
[g], [g],
do_grad = False) grad_depth = grad_depth - 1))
for g in gs]
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
......
Markdown format supported
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment