提交 f0bdbb7e authored 作者: Nicholas Leonard's avatar Nicholas Leonard

changes to method param names. unit tests

上级 1b1e2ec3
...@@ -543,7 +543,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -543,7 +543,7 @@ def grad(cost, wrt, consider_constant=None,
rval, = rval rval, = rval
return rval return rval
def subgrad(wrt, grad_end, known_grads=None, cost=None, details=False): def subgrad(wrt, end, start=None, cost=None, details=False):
''' '''
With respect to wrt, computes gradients of known_grads, cost, With respect to wrt, computes gradients of known_grads, cost,
or both, up to grad_end theano variables in theano digraph. or both, up to grad_end theano variables in theano digraph.
...@@ -566,58 +566,74 @@ def subgrad(wrt, grad_end, known_grads=None, cost=None, details=False): ...@@ -566,58 +566,74 @@ def subgrad(wrt, grad_end, known_grads=None, cost=None, details=False):
parameters parameters
---------- ----------
wrt : list wrt : list
gradients are computed with regard to (wrt) these variables. gradients are computed with respect to (wrt) these variables.
known_grads : dict end : list
parameters, gradients (key, value) in the forward part
(near cost) of the graph for which gradients are known.
These will be used to compute the gradients backwards
up to the variables in grad_end.
grad_end : list
theano variables where to stop the backpropagation of gradients theano variables where to stop the backpropagation of gradients
(they will be considered constant in theano.grad). (they will be considered constant in theano.grad).
start : dict
Theano variables, gradients (key, value) in the forward part
(near a cost) of the graph for which gradients are known.
These will be used to compute the gradients backwards
up to the variables in grad_end (they will be used as known_grads
in theano.grad).
cost : theano scalar cost : theano scalar
additional costs for which to compute the gradients. For additional costs for which to compute the gradients. For
example, these could be weight decay, or l1 constraint on output example, these could be weight decay, or l1 constraint on output
details: bool details: bool
when True, return OrderedDict of wrt, gradients, and lists of when True, return OrderedDict of wrt, gradients, and lists of
gradients derived from known_grads, cost_grads, respectively gradients derived from known_grads, cost_grads, respectively
(in same order as params) (in same order as wrt)
return return
------ ------
Returns an OrderedDict of params (keys), gradients (values) Returns an OrderedDict of params (keys), gradients (values)
''' '''
assert ((cost is not None) or (known_grads is not None)) assert ((cost is not None) or (start is not None))
assert isinstance(grad_end, list) assert isinstance(end, list)
assert isinstance(wrt, list) assert isinstance(wrt, list)
if known_grads is not None: if start is not None:
assert isinstance(known_grads, dict) assert isinstance(start, dict)
kg_grads = None
params = list(set(wrt + end))
start_grads = None
cost_grads = None cost_grads = None
if known_grads is not None: if start is not None:
kg_grads = list(theano.grad(cost=None, wrt=wrt, start_grads = list(
known_grads=known_grads, theano.grad(
consider_constant=grad_end, cost=None, wrt=params, known_grads=start,
disconnected_inputs='ignore')) consider_constant=end,
disconnected_inputs='ignore'
)
)
if cost is not None: if cost is not None:
cost_grads = list(theano.grad(cost=cost, wrt=wrt, cost_grads = list(
consider_constant=grad_end, theano.grad(
disconnected_inputs='ignore')) cost=cost, wrt=params,
consider_constant=end,
disconnected_inputs='ignore'
)
)
grads = None grads = None
if known_grads is None: if start is None:
grads = cost_grads grads = cost_grads
else: else:
grads = kg_grads grads = start_grads
if cost_grads is not None: if cost_grads is not None:
for i in range(len(grads)): for i in range(len(grads)):
grads[i] += cost_grads[i] grads[i] += cost_grads[i]
pgrads = OrderedDict(zip(params, grads))
# separate wrt from end grads:
wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end)
if details: if details:
return grads, kg_grads, cost_grads return wrt_grads, end_grads, start_grads, cost_grads
return grads return wrt_grads, end_grads
def _node_to_pattern(node): def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern """ given an apply node, obtain its connection pattern
......
...@@ -569,7 +569,7 @@ def test_subgrad(): ...@@ -569,7 +569,7 @@ def test_subgrad():
cost2 += theano.tensor.sqr(w2.sum()) cost2 += theano.tensor.sqr(w2.sum())
cost1 = theano.tensor.sqr(w1.sum()) cost1 = theano.tensor.sqr(w1.sum())
params = [[w2,a1],[w1,x]] params = [[w2],[w1]]
costs = [cost2,cost1] costs = [cost2,cost1]
grad_ends = [[a1], [x]] grad_ends = [[a1], [x]]
...@@ -578,30 +578,24 @@ def test_subgrad(): ...@@ -578,30 +578,24 @@ def test_subgrad():
values = [rng.randn(2), rng.randn(3)] values = [rng.randn(2), rng.randn(3)]
values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)] values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]
wrt = [w2, a1, w1, x] wrt = [w2, w1]
cost = cost2 + cost1 cost = cost2 + cost1
true_grads = theano.grad(cost, wrt) true_grads = theano.grad(cost, wrt)
true_grads = theano.function(inputs, true_grads) true_grads = theano.function(inputs, true_grads)
true_grads = true_grads(*values) true_grads = true_grads(*values)
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
known_grad = None next_grad = None
params2 = [] param_grads = []
for i in xrange(2): for i in xrange(2):
param = params[i] param_grad, next_grad = theano.subgrad(
cost = costs[i] wrt=params[i], end=grad_ends[i],
grad_end = grad_ends[i] start=next_grad, cost=costs[i]
pgrad = theano.subgrad(
wrt=param, grad_end=grad_end,
known_grads=known_grad, cost=cost
) )
known_grad = OrderedDict(zip(param,pgrad)) next_grad = OrderedDict(zip(grad_ends[i], next_grad))
params2.extend(pgrad) param_grads.extend(param_grad)
pgrads = theano.function(inputs, params2) pgrads = theano.function(inputs, param_grads)
pgrads = pgrads(*values) pgrads = pgrads(*values)
print(pgrads)
print(true_grads)
for true_grad, pgrad in zip(true_grads, pgrads): for true_grad, pgrad in zip(true_grads, pgrads):
print(true_grad, pgrad) print(true_grad, pgrad)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论