提交 b7a20acb authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -96,21 +96,21 @@ class _test_compile(unittest.TestCase): ...@@ -96,21 +96,21 @@ class _test_compile(unittest.TestCase):
fn() fn()
self.failUnless(go[0].data == 6.0) self.failUnless(go[0].data == 6.0)
def test_prog_noopt(self): def test_noopt(self):
gi, go = graph1() gi, go = graph1()
p = Prog(gi,go) p = Function(gi,go)
self.failUnless(p() == 1.5) self.failUnless(p() == 1.5)
def test_prog_opt(self): def test_opt(self):
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1')) opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph1() gi, go = graph1()
p = Prog(gi,go, optimizer=opt) p = Function(gi,go, optimizer=opt)
self.failUnless(p() == 6.0) self.failUnless(p() == 6.0)
def test_prog_multiout(self): def test_multiout(self):
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1')) opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph2() gi, go = graph2()
p = Prog(gi,go, optimizer=opt) p = Function(gi,go, optimizer=opt)
a,b,c = p() a,b,c = p()
self.failUnless(a == 6.0) self.failUnless(a == 6.0)
self.failUnless(b == 6.0) self.failUnless(b == 6.0)
......
差异被折叠。
...@@ -6,15 +6,15 @@ import gof ...@@ -6,15 +6,15 @@ import gof
_optimizations = None _optimizations = None
def prog_py_opt(inputs, outputs, features=[]): def exec_py_opt(inputs, outputs, features=[]):
"""Return an optimized graph running purely python implementations""" """Return an optimized graph running purely python implementations"""
return Prog(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False) return Function(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False)
def prog_opt(inputs, outputs, features=[]): def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation""" """Return a fast implementation"""
return Prog(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False) return Function(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False)
class Prog: class Function:
"""An 'executable' compiled from a graph """An 'executable' compiled from a graph
This class is meant to be used as a function: the idea is to use This class is meant to be used as a function: the idea is to use
......
差异被折叠。
import gof
class OrderError(Exception):
    """Raised when Grad is manipulated out of order (e.g. adding to a
    gradient that has already been retrieved, or calling bprop twice)."""
class Grad(object):
    """A dictionary-like class, into which derivative expressions may be added.

    Attributes:
      map - dict: result -> grad(result)
      outputs - list: results from which to backpropagate gradient
      did_bprop - bool: has bprop been called?
      items_got - set: results for which we have returned the gradient

    Methods:
      add() - accumulate a gradient expression
      bprop() - recursively construct gradient expressions
      __call__() - retrieve the gradient wrt a given Op or result
      __getitem__() - retrieve the gradient wrt a given Op or result

    This class operates on graphs of nodes which implement the UpdateGradient
    interface.
    """

    def __init__(self, dct=None):
        """Initialize from an optional {result: gradient} mapping.

        Each (key, val) pair is registered via add_output(), so the keys
        become the starting points for bprop().  The default is None rather
        than {} to avoid the shared-mutable-default pitfall; passing a dict
        works exactly as before.
        """
        self.map = {}
        self.outputs = []
        self.did_bprop = False
        self.items_got = set()
        if dct:
            for key, val in dct.items():
                self.add_output(key, val)

    def __contains__(self, item):
        return item in self.map

    def __getitem__(self, r):
        """Return the gradient wrt result r, or None if none was accumulated.

        r is also added to the set of things for which the gradient has been
        given.  Subsequent attempts to modify the gradient wrt r will fail
        with OrderError (raised by add()).
        """
        self.items_got.add(r)
        try:
            return self.map[r]
        except KeyError:
            return None

    def __call__(self, r):
        """Return the gradient wrt result r (alias for __getitem__)."""
        return self.__getitem__(r)

    def add_output(self, r, dr):
        """Register r as a backprop starting point with initial gradient dr."""
        self.add(r, dr)
        self.outputs.append(r)

    def add(self, r, dr):
        """Add dr to the sum of gradients associated with r.

        Raises OrderError if the gradient wrt r has already been retrieved,
        since a later accumulation would be silently lost by the caller.
        """
        if r in self.items_got:
            raise OrderError('gradient has already been retrieved', r)
        if r in self.map:
            self.map[r] = self.map[r] + dr
        else:
            self.map[r] = dr

    def bprop(self):
        """Build a backpropagation graph.

        This function traverses the graph backward from self.outputs, calling
        update_gradient on the ops as it goes.  Ops without an update_gradient
        function are considered not differentiable.  The update_gradient
        function is defined in the UpdateGradient class.

        Raises OrderError if bprop() has already been called.  did_bprop is
        set even if the traversal fails, so a partial bprop is never retried.
        """
        if self.did_bprop:
            raise OrderError('bprop has already been done')
        try:
            outputs = self.outputs
            inputs = gof.graph.inputs(outputs)
            # Visit ops in reverse topological order so every op sees the
            # fully accumulated gradients of its outputs.
            for op in reversed(gof.graph.io_toposort(inputs, outputs)):
                op.update_gradient(self)
        finally:
            self.did_bprop = True
def grad(cost, param=None, cost_grad=1.0):
    """Return symbolic expression of gradient of <cost> wrt <param>.

    If <param> is None, then return a Grad instance, from which the gradients
    of multiple objects can be retrieved using the __getitem__ or __call__
    methods (as in function currying in languages such as scheme and OCaML).

    If <param> is not None, then return the gradient expression for
    d cost / d param.
    """
    gradients = Grad({cost: cost_grad})
    gradients.bprop()
    return gradients if param is None else gradients(param)
class UpdateGradient:
    """Interface that Grad.bprop expects of each differentiable Op."""

    def update_gradient(self, grad_d):
        """Abstract hook: accumulate input gradients into grad_d.

        Override this function to call grad_d.add(r, grad_r) for each
        differentiable input result, r.

        You can assume that the gradient with respect to all output results
        has been accumulated in grad_d; those expressions are available as
        grad_d[o] for o in self.outputs.  If grad_d[o] returns None, this
        function should treat it as an appropriate sort of zero.
        """
        raise AbstractFunctionError()
class SelfGrad(UpdateGradient):
    """Implements update_gradient in terms of the popular self.grad.

    Defines update_gradient (required by Grad.bprop) so that it calls a
    self.grad function like this:

        if len(self.outputs) > 1:
            self.grad(self.inputs, [grad_d[o] for o in self.outputs])
        else:
            self.grad(self.inputs, grad_d[self.outputs[0]])

    self.grad() is an abstract function; see its documentation for the
    expected behaviour.
    """

    def update_gradient(self, grad_d):
        # Call self.grad(inputs, output_gradients) and fold each resulting
        # input gradient into grad_d.
        if len(self.outputs) > 1:
            output_grads = [grad_d[o] for o in self.outputs]
            input_grads = self.grad(self.inputs, output_grads)
        else:
            input_grads = self.grad(self.inputs, grad_d[self.outputs[0]])
        # A single-input op may return a bare result rather than a list;
        # normalize to a list before pairing with self.inputs.
        if len(self.inputs) == 1 and is_result(input_grads):
            input_grads = [input_grads]
        else:
            assert len(input_grads) == len(self.inputs)
        for inp, inp_grad in zip(self.inputs, input_grads):
            grad_d.add(inp, inp_grad)

    def grad(self, *args):
        """Return gradient expressions wrt input arguments.

        If len(self.inputs) == 1: return the input gradient expression.
        If len(self.inputs) >= 2: return a list of input gradient expressions.
        """
        raise AbstractFunctionError()
from tensor import *
from gof import Op, utils, Destroyer, Viewer from gof import Op, utils, Destroyer, Viewer
import gof.op
import gradient
from tensor import *
def upcast(dtype, *dtypes): def _upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
for dtype in dtypes: for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype) z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype) return str(z.dtype)
def wrap_as_tensor(x): def _wrap_as_tensor(x):
if isinstance(x, Tensor): if isinstance(x,Op):
return _wrap_as_tensor(x.out)
elif isinstance(x, Tensor):
return x return x
else: else:
return Tensor(data=x, constant=True) return Tensor(data=x, constant=True)
class TensorOp(Op): # _TensorOp is a convenient base class, permitting to factor the code for the
# Ops in this file.
# It is not necessary to inherit from TensorOp to make an Op that manipulates
# Tensors.
class _TensorOp(Op, gradient.SelfGrad):
nin = -1 nin = -1
nout = 1 nout = 1
cast_method = lambda self, *args: upcast(*args) cast_method = lambda self, *args: _upcast(*args)
def __init__(self, *inputs): def __init__(self, *inputs):
inputs = map(wrap_as_tensor, inputs) inputs = map(_wrap_as_tensor, inputs)
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
...@@ -69,10 +78,10 @@ class TensorOp(Op): ...@@ -69,10 +78,10 @@ class TensorOp(Op):
class UnaryTensorOp(TensorOp): class UnaryTensorOp(_TensorOp):
nin = 1 nin = 1
class BinaryTensorOp(TensorOp): class BinaryTensorOp(_TensorOp):
nin = 2 nin = 2
...@@ -104,7 +113,7 @@ class BinaryTensorOp(TensorOp): ...@@ -104,7 +113,7 @@ class BinaryTensorOp(TensorOp):
def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
def f(x, y): def f(x, y):
x, y = wrap_as_tensor(x), wrap_as_tensor(y) x, y = _wrap_as_tensor(x), _wrap_as_tensor(y)
if 0 not in y.broadcastable: if 0 not in y.broadcastable:
return scalar_f(x, y) return scalar_f(x, y)
if 0 not in x.broadcastable: if 0 not in x.broadcastable:
...@@ -129,7 +138,7 @@ def assert_tensor_scalar(x, a): ...@@ -129,7 +138,7 @@ def assert_tensor_scalar(x, a):
class Elemwise(TensorOp): class Elemwise(_TensorOp):
@staticmethod @staticmethod
def extract_name(name): def extract_name(name):
...@@ -211,7 +220,7 @@ class TensorScalarOp(Elemwise): ...@@ -211,7 +220,7 @@ class TensorScalarOp(Elemwise):
## Dot ## ## Dot ##
######### #########
class Dot(TensorOp): class Dot(_TensorOp):
@staticmethod @staticmethod
def _output_shape(xshape, yshape): def _output_shape(xshape, yshape):
# This describes the logic to calculate numpy.dot(x, y).shape # This describes the logic to calculate numpy.dot(x, y).shape
...@@ -454,7 +463,7 @@ class Fill(Elemwise): ...@@ -454,7 +463,7 @@ class Fill(Elemwise):
#### Unary Operations #### #### Unary Operations ####
########################## ##########################
class Transpose(TensorOp, Viewer): class Transpose(_TensorOp, Viewer):
def view_map(self): def view_map(self):
return {self.out: [self.inputs[0]]} return {self.out: [self.inputs[0]]}
def impl(self, x): def impl(self, x):
...@@ -754,6 +763,8 @@ Tensor.__mul__ = mul ...@@ -754,6 +763,8 @@ Tensor.__mul__ = mul
Tensor.__iadd__ = add_inplace Tensor.__iadd__ = add_inplace
Tensor.__isub__ = sub_inplace Tensor.__isub__ = sub_inplace
Tensor.__imul__ = mul_inplace Tensor.__imul__ = mul_inplace
Tensor.__pow__ = pow
Tensor.__ipow__ = pow_inplace
Tensor.T = property(transpose) Tensor.T = property(transpose)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论