提交 4336f227 · 作者: bergstrj@iro.umontreal.ca

gradient.py written and tested

......@@ -4,11 +4,282 @@
#
import unittest
import numpy
import compile
import tensor
import tensor_ops as T
import tensor
import gof
from gradient import *
import gradient
class posneg(T.TensorOp):
    """Two-output test op: returns its input and the negation of its input."""

    nout = 2

    def impl(self, x):
        """Return the pair (x, -x)."""
        return x, -x

    def grad(self, x, g_outputs):
        """Return the gradient wrt x given the gradients of both outputs.

        g_outputs is the pair (gpos, gneg): the gradient wrt the first
        output (x) and wrt the second output (-x).
        """
        # Explicit unpacking replaces the Python-2-only tuple parameter
        # `(gpos, gneg)` (removed by PEP 3113).  Positional callers are
        # unaffected, and a wrong-length argument still raises.
        gpos, gneg = g_outputs
        return gpos - gneg
class posnegzero(T.TensorOp):
    """Three-output test op: returns its input, its negation, and 0.0."""

    nout = 3

    def impl(self, x):
        """Return the triple (x, -x, 0.0)."""
        return x, -x, 0.0

    def grad(self, x, g_outputs):
        """Return the gradient wrt x given the gradients of all three outputs.

        g_outputs is the triple (gpos, gneg, gzero).  The gradient of the
        constant third output is not used in the result.
        """
        # Explicit unpacking replaces the Python-2-only tuple parameter
        # `(gpos, gneg, gzero)` (removed by PEP 3113).  Unpacking all three
        # keeps the original arity check on the argument.
        gpos, gneg, _gzero = g_outputs
        return gpos - gneg
class _test_grad_sources_inputs(unittest.TestCase):
    """Tests for grad_sources_inputs, driven by small stub ops.

    Each test defines a throw-away gof.op.Op subclass whose grad() returns
    a known value (or fails the test if called), builds a tiny graph from it,
    and checks the gradient map returned by grad_sources_inputs.
    """

    def test_retNone1(self):
        """Test that it is not ok to return None from op.grad()"""
        class retNone(gof.op.Op):
            def __init__(self, arg):
                self.inputs = [gof.result.ResultBase()]
                self.outputs = [gof.result.ResultBase()]

            def grad(self, x, gz):
                pass  # implicitly returns None, which must be rejected

        a = retNone(5)
        try:
            grad_sources_inputs([(a.out, 1)], None)
        except ValueError as e:
            # The message constant is passed as the first exception argument.
            self.assertTrue(e.args[0] is gradient._msg_retNone)
            return
        self.fail()

    def test_retNone1_b(self):
        """Test that it is ok to return [None] from op.grad()"""
        class retNone(gof.op.Op):
            def __init__(self, arg):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase()]

            def grad(self, x, gz):
                return [None]

        i = gof.result.ResultBase()
        a = retNone([i])
        g = grad_sources_inputs([(a.out, 1)], None)
        self.assertTrue(i not in g)

    def test_wrong_rval_len1(self):
        """Test that it is not ok to return the wrong number of gradients"""
        class retNone(gof.op.Op):
            def __init__(self, arg):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase()]

            def grad(self, inputs, gz):
                return [None]

        i = gof.result.ResultBase()
        j = gof.result.ResultBase()

        # One input, one returned gradient: must be accepted.
        a1 = retNone([i])
        grad_sources_inputs([(a1.out, 1)], None)

        # Two inputs, one returned gradient: must be rejected.
        a2 = retNone([i, j])
        try:
            grad_sources_inputs([(a2.out, 1)], None)
        except ValueError as e:
            self.assertTrue(e.args[0] is gradient._msg_badlen)
            return
        self.fail()

    def test_stop_on_all_none(self):
        """Test that op.grad() is not called when output grads are all None"""
        class retNone(gof.op.Op):
            def __init__(self, arg, tst):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase()]
                self.tst = tst

            def grad(self, inputs, gz):
                # Reaching this point means the traversal did not skip the op.
                self.tst.fail()

        i = gof.result.ResultBase()
        a1 = retNone([i], self)
        grad_sources_inputs([(a1.out, None)], None)

    def test_no_invalid_graph(self):
        """Test that bprop fails on an invalid graph"""
        # TODO: placeholder; this test is not written yet and errors on
        # every run until it is.
        raise NotImplementedError()

    def test_1in_1out(self):
        """Test grad is called correctly for a 1-to-1 op"""
        gval = gof.result.ResultBase()

        class O(gof.op.Op):
            def __init__(self):
                self.inputs = [gof.result.ResultBase()]
                self.outputs = [gof.result.ResultBase()]

            def grad(self, x, gz):
                return gval

        a1 = O()
        g = grad_sources_inputs([(a1.outputs[0], 1)], None)
        self.assertTrue(g[a1.inputs[0]] is gval)

    def test_1in_Nout(self):
        """Test grad is called correctly for a 1-to-many op"""
        gval = gof.result.ResultBase()

        class O(gof.op.Op):
            def __init__(self):
                self.inputs = [gof.result.ResultBase()]
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]

            def grad(self, x, g_outputs):
                # Checks that the output gradients arrive as a pair.
                gz1, gz2 = g_outputs
                return gval

        a1 = O()
        g = grad_sources_inputs([(a1.outputs[0], 1)], None)
        self.assertTrue(g[a1.inputs[0]] is gval)

    def test_Nin_1out(self):
        """Test grad is called correctly for a many-to-1 op"""
        gval0 = gof.result.ResultBase()
        gval1 = gof.result.ResultBase()

        class O(gof.op.Op):
            def __init__(self):
                self.inputs = [gof.result.ResultBase(),
                               gof.result.ResultBase()]
                self.outputs = [gof.result.ResultBase()]

            def grad(self, inputs, gz):
                # Checks that the inputs arrive as a pair.
                x0, x1 = inputs
                return (gval0, gval1)

        a1 = O()
        g = grad_sources_inputs([(a1.outputs[0], 1)], None)
        self.assertTrue(g[a1.inputs[0]] is gval0)
        self.assertTrue(g[a1.inputs[1]] is gval1)

    def test_Nin_Nout(self):
        """Test grad is called correctly for a many-to-many op"""
        gval0 = gof.result.ResultBase()
        gval1 = gof.result.ResultBase()

        class O(gof.op.Op):
            def __init__(self):
                self.inputs = [gof.result.ResultBase(),
                               gof.result.ResultBase()]
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]

            def grad(self, inputs, g_outputs):
                # Checks that both arguments arrive as pairs.
                x0, x1 = inputs
                gz0, gz1 = g_outputs
                return gval0, gval1

        a1 = O()
        g = grad_sources_inputs([(a1.outputs[0], 1)], None)
        self.assertTrue(g[a1.inputs[0]] is gval0)
        self.assertTrue(g[a1.inputs[1]] is gval1)

    def test_some_None_ograds(self):
        """Test grad is called when some output gradients are None"""
        class O(gof.op.Op):
            def __init__(self, arg, tst):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]
                self.tst = tst

            def grad(self, inputs, g_out):
                return [1]

        i = gof.result.ResultBase()
        a1 = O([i], self)
        g = grad_sources_inputs([(a1.outputs[0], 1)], None)
        # Compare by value: `is 1` only passed thanks to CPython's small-int
        # caching and is a SyntaxWarning on modern interpreters.
        self.assertEqual(g[i], 1)

    def test_some_None_igrads(self):
        """Test that traversal works properly when an op return some None"""
        class O(gof.op.Op):
            def __init__(self, arg, tst, grad_ok):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]
                self.tst = tst
                self.grad_ok = grad_ok

            def grad(self, inputs, g_out):
                if not self.grad_ok:
                    self.tst.fail()
                else:
                    # Gradient only for the first input.
                    return [1, None]

        i = gof.result.ResultBase()
        j = gof.result.ResultBase()
        k = gof.result.ResultBase()

        # a1's output feeds a2's FIRST input, which does receive a gradient,
        # so the traversal continues through a1 and reaches i.
        a1 = O([i, j], self, True)
        a2 = O([a1.outputs[1], k], self, True)
        g = grad_sources_inputs([(a2.outputs[0], 1)], None)
        self.assertEqual(g[i], 1)
        self.assertTrue(j not in g)
        self.assertTrue(k not in g)

        # a1's output feeds a2's SECOND input, whose gradient is None, so
        # nothing propagates into a1.
        a1 = O([i, j], self, True)
        a2 = O([k, a1.outputs[1]], self, True)
        g = grad_sources_inputs([(a2.outputs[0], 1)], None)
        self.assertEqual(g[k], 1)
        self.assertTrue(i not in g)
        self.assertTrue(j not in g)

    def test_inputs(self):
        """Test that passing inputs shortens the traversal"""
        class O(gof.op.Op):
            def __init__(self, arg, tst, grad_ok):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]
                self.tst = tst
                self.grad_ok = grad_ok

            def grad(self, inputs, g_out):
                g0, g1 = g_out
                if not self.grad_ok:
                    self.tst.fail()
                else:
                    if g1:
                        return [g0, g0 + g1]
                    else:
                        return [g0, g0]

        i = gof.result.ResultBase()
        j = gof.result.ResultBase()
        k = gof.result.ResultBase()
        a1 = O([i, j], self, True)
        a2 = O([k, a1.outputs[1]], self, True)
        # a1.outputs[0] is given twice on purpose: repeated sources are summed.
        g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1], 4),
                                 (a1.outputs[0], 3), (a1.outputs[0], 3)],
                                a1.outputs)
        self.assertEqual(g[a2.inputs[0]], 1)
        self.assertEqual(g[a2.inputs[1]], 5)    # 4 (source) + 1 (from a2)
        self.assertEqual(g[a1.outputs[0]], 6)   # 3 + 3
        self.assertEqual(g[a1.outputs[1]], 5)
        # a1.outputs were declared as graph inputs, so a1 itself is not visited.
        self.assertTrue(a1.inputs[0] not in g)
        self.assertTrue(a1.inputs[1] not in g)

    def test_multiple_sources(self):
        """Test that passing multiple sources works"""
        class O(gof.op.Op):
            def __init__(self, arg, tst, grad_ok):
                self.inputs = arg
                self.outputs = [gof.result.ResultBase(),
                                gof.result.ResultBase()]
                self.tst = tst
                self.grad_ok = grad_ok

            def grad(self, inputs, g_out):
                g0, g1 = g_out
                if not self.grad_ok:
                    self.tst.fail()
                else:
                    if g1:
                        return [g0, g0 + g1]
                    else:
                        return [g0, g0]

        i = gof.result.ResultBase()
        j = gof.result.ResultBase()
        k = gof.result.ResultBase()
        a1 = O([i, j], self, True)
        a2 = O([k, a1.outputs[1]], self, True)
        g = grad_sources_inputs([(a2.outputs[0], 1), (a1.outputs[1], 4),
                                 (a1.outputs[0], 3), (a1.outputs[0], 3)],
                                None)
        self.assertEqual(g[a2.inputs[0]], 1)
        self.assertEqual(g[a2.inputs[1]], 5)    # 4 (source) + 1 (from a2)
        self.assertEqual(g[a1.outputs[0]], 6)   # 3 + 3
        self.assertEqual(g[a1.outputs[1]], 5)
        # With no graph inputs given, the traversal continues through a1.
        self.assertEqual(g[a1.inputs[0]], 6)
        self.assertEqual(g[a1.inputs[1]], 11)   # 6 + 5
class _test_grad(unittest.TestCase):
    """Tests for the grad(cost, param) convenience wrapper."""

    class O(gof.op.Op):
        """Two-input, two-output stub op whose grad() returns fixed results."""

        def __init__(self):
            self.inputs = [gof.result.ResultBase(), gof.result.ResultBase()]
            self.outputs = [gof.result.ResultBase(), gof.result.ResultBase()]
            # Sentinel gradients the tests look for by identity.
            self.gval0 = gof.result.ResultBase()
            self.gval1 = gof.result.ResultBase()

        def grad(self, inputs, g_outputs):
            # Explicit unpacking replaces the Python-2-only tuple parameters
            # and keeps the check that both arguments arrive as pairs.
            x0, x1 = inputs
            gz0, gz1 = g_outputs
            return self.gval0, self.gval1

    def test_1param(self):
        """grad: Test passing a single result param"""
        a1 = _test_grad.O()
        self.assertTrue(a1.gval0 is grad(a1.outputs[0], a1.inputs[0]))

    def test_Nparam(self):
        """grad: Test passing multiple result params"""
        a1 = _test_grad.O()
        g0, g1 = grad(a1.outputs[0], a1.inputs)
        self.assertTrue(a1.gval0 is g0)
        self.assertTrue(a1.gval1 is g1)

    def test_1None_rval(self):
        """grad: Test returning a single None from grad"""
        a1 = _test_grad.O()
        # Neither a sibling output nor an unrelated object has a gradient.
        self.assertTrue(grad(a1.outputs[0], a1.outputs[1]) is None)
        self.assertTrue(grad(a1.outputs[0], 'wtf') is None)

    def test_NNone_rval(self):
        """grad: Test returning some Nones from grad"""
        a1 = _test_grad.O()
        g0, g1, g2 = grad(a1.outputs[0], a1.inputs + ['wtf'])
        self.assertTrue(a1.gval0 is g0)
        self.assertTrue(a1.gval1 is g1)
        self.assertTrue(g2 is None)
def matrix():
    """Build a fresh tensor.Tensor from the fixed arguments 'float64', [0, 0].

    NOTE(review): presumably a float64 matrix placeholder, going by the
    function name; confirm the meaning of the second argument against
    tensor.Tensor.
    """
    ctor_args = ('float64', [0, 0])
    return tensor.Tensor(*ctor_args)
......@@ -17,14 +288,11 @@ def matrices(n):
return [matrix() for i in xrange(n)]
class _testNone(unitTest.TestCase):
def test0(self):
class _testCase_matinv:# (unittest.TestCase):
def setUp(self):
numpy.random.seed(1)
def matinv(self,dim):
# symbolic program
a,b = matrices(2)
......@@ -55,15 +323,6 @@ class _testCase_matinv:# (unittest.TestCase):
class _testCase_old:#(unittest.TestCase):
class posneg(T._TensorOp):
nout=2
def impl(x): return x, -x
def grad(x, gpos, gneg): return gpos - gneg
class posnegzero(T._TensorOp):
nout=3
def impl(x): return x, -x, 0.0
def grad(x, gpos, gneg, gzero): return gpos - gneg
def setUp(self):
numpy.random.seed(1)
......@@ -143,37 +402,6 @@ class _testCase_old:#(unittest.TestCase):
self.assertEqual(max, min)
self.assertEqual(max, 0.0)
def test_repeat_bprop(self):
    """Refuse to repeat bprop"""
    # NOTE(review): `_testCase`, `Grad` and `wrappers` are not defined in
    # the visible part of this file (the enclosing class shown is
    # `_testCase_old`); confirm these names resolve.
    a = numpy.ones((3, 3, 3))
    b, c, d = _testCase.posnegzero(a)
    g = Grad({b: wrappers.wrap(a), c: wrappers.wrap(a)})
    g.bprop()
    try:
        g.bprop()
    except Exception as e:
        self.assertEqual(
            str(e),
            'bprop has already been done. Consider calling with maybe_redo=True.')
    else:
        # The failure is reported outside the `try` body.  The original
        # called assertEqual with a single argument inside it, so the
        # resulting TypeError was swallowed by the handler above and the
        # test failed with a misleading message comparison.
        self.fail('should have raised')
def test_repeat_bprop1(self):
    """Force repeat bprop"""
    # NOTE(review): `_testCase`, `Grad` and `wrappers` are not defined in
    # the visible part of this file; confirm these names resolve.
    ones = numpy.ones((3, 3, 3))
    zeros = numpy.zeros((3, 3, 3))
    b, c, d = _testCase.posnegzero(ones)
    g = Grad({b: wrappers.wrap(ones), c: wrappers.wrap(zeros)})
    # Second call passes maybe_redo=True to request the repeat explicitly.
    g.bprop()
    g.bprop(maybe_redo=True)
    largest = numpy.max(g(ones))
    smallest = numpy.min(g(ones))
    # Equal extremes mean the gradient is uniform; its value must be 2.0.
    self.assertEqual(largest, smallest)
    self.assertEqual(largest, 2.0)
def tearDown(self):
core.pop_mode()
......
import gof
import gof, gof.result
_msg_retNone = 'op.grad(...) returned None, consider returning [None]'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
def _unpack_result(lst):
if len(lst) > 1:
return lst
else
else:
return lst[0]
def _pack_result(arg):
    """Wrap a lone result in a list; return anything else unchanged.

    arg - either a single gof.result.ResultBase instance, or an object that
          is already a sequence of gradients

    Used to normalise the return value of op.grad(), which may return a
    bare result when there is only one gradient.
    """
    # Only the isinstance-based check is kept; the earlier
    # `gof.result.is_result` variant and its unconditional `return arg`
    # made this branch unreachable.
    if isinstance(arg, gof.result.ResultBase):
        return [arg]
    return arg
def _accumulate_grad(gmap, r, g_r):
    """Add gradient g_r for result r into gmap; a None gradient is ignored."""
    if g_r is None:
        return
    if r in gmap:
        # Deliberately `a = a + b` rather than `a += b`, as in the original:
        # an existing gradient object is never modified through __iadd__.
        gmap[r] = gmap[r] + g_r
    else:
        gmap[r] = g_r


def grad_sources_inputs(sources, graph_inputs):
    """Return a dictionary mapping each result necessary for a source to its gradient

    sources - a list of gradient sources (explained below)
    graph_inputs - a list of results considered to be constant; if None,
        gof.graph.inputs() of the source results is used

    A gradient source is a pair (r, g_r), in which r is a result, and g_r is a
    result that is a gradient wrt r.  A source whose g_r is None is ignored,
    and several sources naming the same r are summed.

    The ops between graph_inputs and the source results are visited in
    reverse topological order.  For each op whose output gradients are not
    all None, op.grad(inputs, output_gradients) is called; a single input or
    a single output gradient is passed bare rather than as a one-element
    list.  op.grad may return None in place of a gradient for an individual
    input, in which case that input gets no entry in the returned dictionary.

    Raises ValueError if op.grad returns None itself (return [None] instead)
    or returns a different number of gradients than the op has inputs.

    NOTE(review): part of the original docstring is elided in the diff view
    this text was recovered from; merge any missing details back in.
    """
    gmap = {}
    for r, g_r in sources:
        _accumulate_grad(gmap, r, g_r)

    # Snapshot the source results before gmap grows during the traversal.
    graph_outputs = list(gmap)

    if graph_inputs is None:
        graph_inputs = gof.graph.inputs(graph_outputs)

    ordered_ops = gof.graph.io_toposort(graph_inputs, graph_outputs)
    for op in reversed(ordered_ops):
        g_outputs = [gmap.get(o, None) for o in op.outputs]

        # If all output gradients are None there is nothing to propagate.
        if all(g is None for g in g_outputs):
            continue

        output_arg = _unpack_result(g_outputs)
        input_arg = _unpack_result(op.inputs)
        op_grad = op.grad(input_arg, output_arg)
        if op_grad is None:
            # The message constant stays the first exception argument;
            # callers compare it by identity.
            raise ValueError(_msg_retNone, op.__class__)

        g_inputs = _pack_result(op_grad)
        if len(g_inputs) != len(op.inputs):
            raise ValueError(_msg_badlen,
                             op.__class__,
                             len(g_inputs),
                             len(op.inputs))

        for r, g_r in zip(op.inputs, g_inputs):
            _accumulate_grad(gmap, r, g_r)
    return gmap
def grad(cost, param):
    """Return symbolic expression of gradient of <cost> wrt <param>.

    If <param> is a list, then return a list containing the gradient of cost wrt
    each element of the list.

    Any <param> (or list element) for which no gradient was computed yields
    None instead of raising KeyError.  The gradient of <cost> wrt itself is
    seeded with 1.0.
    """
    inputs = gof.graph.inputs([cost])
    gmap = grad_sources_inputs([(cost, 1.0)], inputs)
    if isinstance(param, list):
        return [gmap.get(p, None) for p in param]
    return gmap.get(param, None)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论