提交 828ed78c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

connection pattern for scan

上级 bc74b395
...@@ -1248,6 +1248,67 @@ class Scan(PureOp): ...@@ -1248,6 +1248,67 @@ class Scan(PureOp):
ipos += len(otaps) ipos += len(otaps)
return ipos + opos return ipos + opos
def connection_pattern(self, node):
    """Return which outer inputs each outer output of ``node`` depends on.

    The result has one row per entry of ``node.inputs`` and, in each row,
    one boolean per entry of ``node.outputs``: ``result[i][o]`` is ``True``
    when output ``o`` is considered connected to input ``i``.

    * Row 0 (``node.inputs[0]``) is always marked connected to every
      output.
    * An input found in ``self.outer_nitsot(node)`` or
      ``self.outer_shared(node)`` is marked connected to every output
      (its gradient should be undefined, not disconnected).
    * Any other input is connected to an output when, for at least one of
      the inner outputs backing that output, the gradient with respect to
      at least one of the inner inputs backing that input is not ``None``.

    :param node: the Apply node of this op whose inputs/outputs are
        inspected.
    :return: list of lists of booleans, indexed ``[input][output]``.
    """
    outer_inps = node.inputs[1:]
    n_outer_outs = len(node.outputs)

    # Every row is a fresh list so that rows can be updated independently.
    pattern = [[True] * n_outer_outs]
    pattern += [[False] * n_outer_outs for _ in outer_inps]

    def compute_gradient(y, g_y, diff_inputs):
        # Gradient of ``y`` (seeded with ``g_y``) w.r.t. each variable in
        # ``diff_inputs``; ``None`` for variables ``y`` does not reach.
        gmp = gradient.grad_sources_inputs(
            [(y, g_y)],
            [x for x in theano.gof.graph.inputs([y])
             if x in diff_inputs])
        return [gmp.get(p, None) for p in diff_inputs]

    def _inner_out_width(idx):
        # Number of inner outputs backing outer output ``idx``: one per
        # output tap for the first ``n_mit_mot`` outputs, one otherwise.
        if idx < self.n_mit_mot:
            return len(self.mitmot_out_taps()[idx])
        return 1

    def _get_inner_outs(oidx):
        # The slice starts after the inner outputs of all previous outer
        # outputs and spans the width of *this* output.  (The previous
        # code advanced the end by the width of the preceding output.)
        start = sum(_inner_out_width(p) for p in range(oidx))
        return self.outputs[start:start + _inner_out_width(oidx)]

    def _inner_inp_width(idx):
        # Number of inner inputs backing the ``idx``-th entry of
        # ``node.inputs[1:]``: one for each of the first ``n_seqs``
        # entries, one per tap in ``tap_array`` afterwards.
        if idx < self.n_seqs:
            return 1
        return len(self.tap_array[idx - self.n_seqs])

    def _get_inner_inps(iidx):
        # ``iidx`` indexes ``node.inputs[1:]``, hence the ``+ 1`` when
        # looking the variable up in ``node.inputs``.
        outer_inp = node.inputs[iidx + 1]
        if (outer_inp in self.outer_nitsot(node) or
                outer_inp in self.outer_shared(node)):
            return None
        outer_non_seqs = self.outer_non_seqs(node)
        if outer_inp in outer_non_seqs:
            loc_idx = outer_non_seqs.index(outer_inp)
            return [self.inner_non_seqs()[loc_idx]]
        start = sum(_inner_inp_width(p) for p in range(iidx))
        return self.inputs[start:start + _inner_inp_width(iidx)]

    # The inner inputs of an outer input do not depend on the output being
    # examined, so resolve them once instead of once per output.
    inner_inps_by_input = [_get_inner_inps(iidx)
                           for iidx in range(len(outer_inps))]

    for oidx in range(n_outer_outs):
        inner_outs = _get_inner_outs(oidx)
        for iidx, inner_inps in enumerate(inner_inps_by_input):
            if inner_inps is None:
                # The gradient should be undefined, not disconnected
                pattern[iidx + 1][oidx] = True
                continue
            for inner_out in inner_outs:
                grads = compute_gradient(
                    inner_out,
                    safe_new(inner_out, dtype='float64'),
                    inner_inps)
                if any(g is not None for g in grads):
                    pattern[iidx + 1][oidx] = True
                    # One connected inner output is enough.
                    break
    return pattern
### GRAD FUNCTION ### GRAD FUNCTION
def grad(self, inputs, dC_douts): def grad(self, inputs, dC_douts):
outs = self(*inputs) outs = self(*inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论