提交 5e9bf42a authored 作者: nouiz's avatar nouiz

Merge pull request #646 from pascanur/fix_rop_shared_variables_rebase

Fix rop shared variables rebase
...@@ -250,10 +250,12 @@ def Rop(f, wrt, eval_points): ...@@ -250,10 +250,12 @@ def Rop(f, wrt, eval_points):
for pack in enumerate(zip(wrt, eval_points)): for pack in enumerate(zip(wrt, eval_points)):
i = pack[0] i = pack[0]
wrt_elem, eval_point = pack[1] wrt_elem, eval_point = pack[1]
if not isinstance(wrt_elem, gof.Variable):
wrt_elem = as_tensor_variable(wrt_elem) wrt_elem = as_tensor_variable(wrt_elem)
if not isinstance(eval_point, gof.Variable):
eval_point = as_tensor_variable(eval_point) eval_point = as_tensor_variable(eval_point)
try:
wrt_dim = len(wrt_elem.type.broadcastable) wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable) eval_dim = len(eval_point.type.broadcastable)
...@@ -265,6 +267,10 @@ def Rop(f, wrt, eval_points): ...@@ -265,6 +267,10 @@ def Rop(f, wrt, eval_points):
str(wrt_dim) + str(wrt_dim) +
' versus ' + ' versus ' +
str(eval_dim)) str(eval_dim))
except:
# wrt_elem and eval_point can be non-tensor variable which do
# not have broadcastable flags
pass
seen_nodes = {} seen_nodes = {}
...@@ -283,7 +289,12 @@ def Rop(f, wrt, eval_points): ...@@ -283,7 +289,12 @@ def Rop(f, wrt, eval_points):
if inp in wrt: if inp in wrt:
local_eval_points.append(eval_points[wrt.index(inp)]) local_eval_points.append(eval_points[wrt.index(inp)])
elif inp.owner is None: elif inp.owner is None:
try:
local_eval_points.append(inp.zeros_like()) local_eval_points.append(inp.zeros_like())
except:
# None should be used for non-differentiable
# arguments, like for example random states
local_eval_points.append(None)
elif inp.owner in seen_nodes: elif inp.owner in seen_nodes:
local_eval_points.append( local_eval_points.append(
......
...@@ -175,6 +175,9 @@ class mrg_uniform_base(Op): ...@@ -175,6 +175,9 @@ class mrg_uniform_base(Op):
def grad(self,inputs,ograd): def grad(self,inputs,ograd):
return [None for i in inputs] return [None for i in inputs]
def R_op(self, inputs, eval_points):
    """Return a list holding one ``None`` per entry of ``eval_points``.

    ``inputs`` is accepted but not consulted; only the number of
    evaluation points determines the length of the result.
    """
    return [None for _unused in eval_points]
class mrg_uniform(mrg_uniform_base): class mrg_uniform(mrg_uniform_base):
#CPU VERSION #CPU VERSION
......
...@@ -1532,15 +1532,19 @@ class Scan(PureOp): ...@@ -1532,15 +1532,19 @@ class Scan(PureOp):
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs, '_rop') self.outputs, '_rop')
self_inputs = rval[0] self_inputs = rval[0]
rop_of_inputs = rval[0][:self.n_seqs + self.n_outs] + \
rval[0][self.n_seqs + self.n_outs + self.n_shared_outs:]
self_outputs = rval[1] self_outputs = rval[1]
# Step 1. Compute the R_op of the inner function # Step 1. Compute the R_op of the inner function
inner_eval_points = [scan_utils.safe_new(x, '_evalpoint') inner_eval_points = [scan_utils.safe_new(x, '_evalpoint')
for x in self_inputs] for x in rop_of_inputs]
if self.as_while: if self.as_while:
rop_self_outputs = self_outputs[:-1] rop_self_outputs = self_outputs[:-1]
else: else:
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
rop_outs = tensor.Rop(rop_self_outputs, self_inputs, inner_eval_points) if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple): if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs] rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan # Step 2. Figure out what corresponds to what in the scan
...@@ -1559,7 +1563,7 @@ class Scan(PureOp): ...@@ -1559,7 +1563,7 @@ class Scan(PureOp):
info['n_sit_sot'] = self.n_sit_sot * 2 info['n_sit_sot'] = self.n_sit_sot * 2
info['n_mit_mot'] = self.n_mit_mot * 2 info['n_mit_mot'] = self.n_mit_mot * 2
info['n_nit_sot'] = self.n_nit_sot * 2 info['n_nit_sot'] = self.n_nit_sot * 2
info['n_shared_outs'] = self.n_shared_outs * 2 info['n_shared_outs'] = self.n_shared_outs
info['gpu'] = False info['gpu'] = False
info['as_while'] = self.as_while info['as_while'] = self.as_while
info['profile'] = self.profile info['profile'] = self.profile
...@@ -1587,7 +1591,14 @@ class Scan(PureOp): ...@@ -1587,7 +1591,14 @@ class Scan(PureOp):
ib = 0 ib = 0
e = 1 + self.n_seqs e = 1 + self.n_seqs
ie = self.n_seqs ie = self.n_seqs
scan_seqs = inputs[b:e] + eval_points[b:e] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_seqs = inputs[b:e] + clean_eval_points
inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_seqs = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_MOT sequences ... # MIT_MOT sequences ...
...@@ -1596,7 +1607,14 @@ class Scan(PureOp): ...@@ -1596,7 +1607,14 @@ class Scan(PureOp):
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[:self.n_mit_mot]])) self.tap_array[:self.n_mit_mot]]))
scan_mit_mot = inputs[b:e] + eval_points[b:e] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_mit_mot = inputs[b:e] + clean_eval_points
inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_mit_mot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
# MIT_SOT sequences ... # MIT_SOT sequences ...
...@@ -1606,6 +1624,13 @@ class Scan(PureOp): ...@@ -1606,6 +1624,13 @@ class Scan(PureOp):
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot:\ self.tap_array[self.n_mit_mot:\
self.n_mit_mot + self.n_mit_sot]])) self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_mit_sot = inputs[b:e] + eval_points[b:e] scan_mit_sot = inputs[b:e] + eval_points[b:e]
inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_mit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
...@@ -1614,7 +1639,14 @@ class Scan(PureOp): ...@@ -1614,7 +1639,14 @@ class Scan(PureOp):
e = e + self.n_sit_sot e = e + self.n_sit_sot
ib = ie ib = ie
ie = ie + self.n_sit_sot ie = ie + self.n_sit_sot
scan_sit_sot = inputs[b:e] + eval_points[b:e] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_sit_sot = inputs[b:e] + clean_eval_points
inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_sit_sot = self_inputs[ib:ie] + inner_eval_points[ib:ie]
#Shared outs ... #Shared outs ...
...@@ -1622,8 +1654,8 @@ class Scan(PureOp): ...@@ -1622,8 +1654,8 @@ class Scan(PureOp):
e = e + self.n_shared_outs e = e + self.n_shared_outs
ib = ie ib = ie
ie = ie + self.n_shared_outs ie = ie + self.n_shared_outs
scan_shared = inputs[b:e] + eval_points[b:e] scan_shared = inputs[b:e]
inner_shared = self_inputs[ib:ie] + inner_eval_points[ib:ie] inner_shared = self_inputs[ib:ie]
# NIT_SOT sequences # NIT_SOT sequences
b = e b = e
...@@ -1631,8 +1663,15 @@ class Scan(PureOp): ...@@ -1631,8 +1663,15 @@ class Scan(PureOp):
scan_nit_sot = inputs[b:e] * 2 scan_nit_sot = inputs[b:e] * 2
# All other arguments # All other arguments
scan_other = inputs[e:] + eval_points[e:] clean_eval_points = []
inner_other = self_inputs[ie:] + inner_eval_points[ie:] for inp, evp in zip(inputs[e:], eval_points[e:]):
if evp is not None:
clean_eval_points.append(evp)
else:
clean_eval_points.append(inp.zeros_like())
scan_other = inputs[e:] + clean_eval_points
# inner_eval_points do not have entries for shared variables
inner_other = self_inputs[ie:] + inner_eval_points[ib:]
# Outputs # Outputs
n_mit_mot_outs = int(numpy.sum([len(x) for x in n_mit_mot_outs = int(numpy.sum([len(x) for x in
...@@ -1652,7 +1691,7 @@ class Scan(PureOp): ...@@ -1652,7 +1691,7 @@ class Scan(PureOp):
inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e] inner_out_nit_sot = self_outputs[b:e] + rop_outs[b:e]
b = e b = e
e = e + self.n_shared_outs e = e + self.n_shared_outs
inner_out_shared = self_outputs[b:e] + rop_outs[b:e] inner_out_shared = self_outputs[b:e]
inner_ins = (inner_seqs + inner_ins = (inner_seqs +
inner_mit_mot + inner_mit_mot +
...@@ -1695,9 +1734,7 @@ class Scan(PureOp): ...@@ -1695,9 +1734,7 @@ class Scan(PureOp):
b = e + self.n_nit_sot b = e + self.n_nit_sot
e = e + self.n_nit_sot * 2 e = e + self.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
b = e + self.n_shared_outs final_outs += [None]*self.n_shared_outs
e = e + self.n_shared_outs * 2
final_outs += outputs[b:e]
return final_outs return final_outs
......
...@@ -2389,6 +2389,83 @@ class T_Scan(unittest.TestCase): ...@@ -2389,6 +2389,83 @@ class T_Scan(unittest.TestCase):
f2 = theano.function([], gx) f2 = theano.function([], gx)
assert numpy.allclose(f2(), numpy.ones((10,))) assert numpy.allclose(f2(), numpy.ones((10,)))
def test_rop2(self):
    """Check ``Rop`` through a scan whose step draws from a random stream.

    The R-operator results with respect to the sequence, the initial
    state and the weight matrix are compared against reference values
    assembled with ``grad``, one output element at a time.
    """
    seed = utt.fetch_seed()
    rng = numpy.random.RandomState(seed)
    floatX = theano.config.floatX
    tt = theano.tensor

    def centered(shape):
        # Uniform samples shifted by -0.5, cast to floatX.
        return numpy.array(rng.uniform(size=shape) - .5, dtype=floatX)

    # The order of these draws determines the sampled values; keep it.
    v_u = centered((3, 5))
    v_W = centered((5, 5))
    v_h0 = centered((5,))
    v_eu = centered((3, 5))
    v_eW = centered((5, 5))
    v_eh0 = centered((5,))

    def rnn_fn(u_t, y_prev, weights):
        # The step adds a tiny random perturbation so the inner graph
        # contains a random-stream variable.
        srng = tt.shared_randomstreams.RandomStreams(seed)
        noise = srng.uniform(size=v_h0.shape) * numpy.float32(1e-6)
        return tt.tanh(tt.dot(weights, u_t + y_prev + noise))

    def with_shape(var, value, label):
        # Pin the symbolic variable to the shape of its test value.
        shaped = tt.specify_shape(var, value.shape)
        shaped.name = label
        return shaped

    u = tt.matrix('U')
    h0 = tt.vector('h0')
    W = tt.matrix('W')
    _u = with_shape(u, v_u, '_U')
    _h0 = with_shape(h0, v_h0, '_h0')
    _W = with_shape(W, v_W, '_W')

    o, _ = theano.scan(rnn_fn,
                       sequences=_u,
                       outputs_info=_h0,
                       non_sequences=_W,
                       name='rnn_fn')
    # Only the last step of the recurrence is differentiated.
    o = o[-1]

    eu = tt.matrix('eu')
    eh0 = tt.vector('eh0')
    eW = tt.matrix('eW')

    nwo_u = tt.Rop(o, _u, eu)
    nwo_h0 = tt.Rop(o, _h0, eh0)
    nwo_W = tt.Rop(o, _W, eW)
    fn_inputs = (u, h0, W, eu, eh0, eW)
    fn_rop = theano.function(list(fn_inputs),
                             [nwo_u, nwo_h0, nwo_W, o],
                             on_unused_input='ignore')

    def jacobian_times(wrt_pos, direction, scan_name):
        # Reference value: for every element of ``o`` take its gradient
        # with respect to the chosen input and contract with ``direction``.
        def step(i, out, u_in, h0_in, W_in, direction_in):
            target = (u_in, h0_in, W_in)[wrt_pos]
            return (theano.tensor.grad(out[i], target) * direction_in).sum()

        result, _ = theano.scan(step,
                                sequences=tensor.arange(o.shape[0]),
                                non_sequences=[o, u, h0, W, direction],
                                name=scan_name)
        return result

    n2o_u = jacobian_times(0, eu, 'jacobU')
    n2o_h0 = jacobian_times(1, eh0, 'jacobh')
    n2o_W = jacobian_times(2, eW, 'jacobW')
    fn_test = theano.function(list(fn_inputs),
                              [n2o_u, n2o_h0, n2o_W, o],
                              on_unused_input='ignore')

    values = (v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
    vnu, vnh0, vnW, vno = fn_rop(*values)
    tnu, tnh0, tnW, tno = fn_test(*values)

    # The fourth output (``o`` itself) is returned but not compared.
    for got, expected in ((vnu, tnu), (vnh0, tnh0), (vnW, tnW)):
        assert numpy.allclose(got, expected, atol=1e-6)
def test_rop(self): def test_rop(self):
seed = utt.fetch_seed() seed = utt.fetch_seed()
rng = numpy.random.RandomState(seed) rng = numpy.random.RandomState(seed)
...@@ -2430,7 +2507,8 @@ class T_Scan(unittest.TestCase): ...@@ -2430,7 +2507,8 @@ class T_Scan(unittest.TestCase):
nwo_h0 = theano.tensor.Rop(o, _h0, eh0) nwo_h0 = theano.tensor.Rop(o, _h0, eh0)
nwo_W = theano.tensor.Rop(o, _W, eW) nwo_W = theano.tensor.Rop(o, _W, eW)
fn_rop = theano.function([u, h0, W, eu, eh0, eW], fn_rop = theano.function([u, h0, W, eu, eh0, eW],
[nwo_u, nwo_h0, nwo_W]) [nwo_u, nwo_h0, nwo_W],
on_unused_input='ignore')
n2o_u, _ = theano.scan(lambda i, o, u, h0, W, eu: \ n2o_u, _ = theano.scan(lambda i, o, u, h0, W, eu: \
(theano.tensor.grad(o[i], u) * eu).sum(), (theano.tensor.grad(o[i], u) * eu).sum(),
...@@ -2451,7 +2529,8 @@ class T_Scan(unittest.TestCase): ...@@ -2451,7 +2529,8 @@ class T_Scan(unittest.TestCase):
name='jacobW') name='jacobW')
fn_test = theano.function([u, h0, W, eu, eh0, eW], fn_test = theano.function([u, h0, W, eu, eh0, eW],
[n2o_u, n2o_h0, n2o_W]) [n2o_u, n2o_h0, n2o_W],
on_unused_input='ignore')
vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
......
...@@ -251,6 +251,9 @@ class RandomFunction(gof.Op): ...@@ -251,6 +251,9 @@ class RandomFunction(gof.Op):
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
return [None for i in inputs] return [None for i in inputs]
def R_op(self, inputs, eval_points):
    """Return a list holding one ``None`` per entry of ``eval_points``.

    ``inputs`` is accepted but not consulted; only the number of
    evaluation points determines the length of the result.
    """
    return [None for _unused in eval_points]
def _infer_ndim_bcast(ndim, shape, *args): def _infer_ndim_bcast(ndim, shape, *args):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论