提交 c60e26be authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added grad to OpFromGraph

上级 f5f99c92
...@@ -156,6 +156,16 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -156,6 +156,16 @@ class T_OpFromGraph(unittest.TestCase):
res = fn(xv, yv, zv) res = fn(xv, yv, zv)
assert res.shape == (2, 5) assert res.shape == (2, 5)
assert numpy.all(180.0 == res) assert numpy.all(180.0 == res)
def test_grad(self):
x, y, z = T.matrices('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e], linker='c|py')
f = op(x, y, z)
f = f - T.grad(f, y)
fn = function([x, y, z], [f])
xv, yv, zv = N.ones((2, 2)), N.ones((2, 2))*3, N.ones((2, 2))*5
assert numpy.all(11.0 == fn(xv, yv, zv))
......
...@@ -204,13 +204,25 @@ class OpFromGraph(gof.Op): ...@@ -204,13 +204,25 @@ class OpFromGraph(gof.Op):
""" """
def __init__(self, inputs, outputs, **kwargs): def __init__(self, inputs, outputs, **kwargs):
do_grad = kwargs.pop('do_grad', True)
if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'): if kwargs.get('borrow_outputs') or kwargs.get('unpack_single'):
raise ValueError('The borrow_outputs and unpack_single options cannot be True') raise ValueError('The borrow_outputs and unpack_single options cannot be True')
kwargs['unpack_single'] = False kwargs['unpack_single'] = False
kwargs['borrow_outputs'] = False kwargs['borrow_outputs'] = False
self.fn = function(inputs, outputs, **kwargs) self.fn = function(inputs, outputs, **kwargs)
self.inputs = inputs
self.outputs = outputs
self.input_types = [input.type for input in inputs] self.input_types = [input.type for input in inputs]
self.output_types = [output.type for output in outputs] self.output_types = [output.type for output in outputs]
if do_grad:
import gradient as G
output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
gs = map(gd.get, self.inputs)
self.grad_ops = [OpFromGraph(inputs + output_grads,
[g],
do_grad = False)
for g in gs]
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
...@@ -225,6 +237,12 @@ class OpFromGraph(gof.Op): ...@@ -225,6 +237,12 @@ class OpFromGraph(gof.Op):
for output, result in zip(outputs, results): for output, result in zip(outputs, results):
output[0] = result output[0] = result
def grad(self, inputs, output_grads):
if hasattr(self, 'grad_ops'):
return [go(*(inputs + output_grads)) for go in self.grad_ops]
else:
raise NotImplementedError
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论