提交 54ba8808 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Reimplement connection_pattern() to avoid calling grad()

上级 46ebe84a
......@@ -1377,31 +1377,12 @@ class Scan(PureOp):
# that had many scans nested inside one another.
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# The gradient wrt n_steps is disconnected
connection_pattern = [[False for output in node.outputs]]
connection_pattern += [[False for output in node.outputs]
for x in node.inputs[1:]]
def compute_gradient(y, g_y, diff_inputs):
rval = []
gmp = OrderedDict()
consider_inps = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs]
for x in consider_inps:
try:
gmp[x] = gradient.grad(cost=None,
known_grads={y: g_y}, wrt=x)
except gradient.NullTypeGradError:
# It means the gradient is undefined (which implies
# that it is connected).
# Warning: x is not the right gradient here, but the only
# thing we will check later is whether it is None.
gmp[x] = x
except gradient.DisconnectedInputError:
gmp[x] = None
return [gmp.get(p, None) for p in diff_inputs]
def _get_inner_outs(oidx):
# Define helper functions
def _get_inner_outs_idx(oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
if self.n_mit_mot > 0:
e = len(self.mitmot_out_taps()[0])
......@@ -1413,13 +1394,13 @@ class Scan(PureOp):
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return self.outputs[s:e]
def _get_inner_inps(outer_iidx):
"""Given the index of an outer input, return the corresponding
inner input(s) as a sequence.
"""
return range(s, e)
def _get_inner_inps_idx(outer_iidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
......@@ -1429,46 +1410,35 @@ class Scan(PureOp):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
# The inner inputs can be selected this way because the indices in
# inner_iidxs are consecutive and in ascending order
if len(inner_iidxs) > 0:
inner_inputs = self.inputs[inner_iidxs[0]:inner_iidxs[-1]+1]
else:
inner_inputs = []
return inner_iidxs
return inner_inputs
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern(node)
for oidx, out in enumerate(node.outputs):
for iidx, inp in enumerate(node.inputs[1:]):
ols = _get_inner_outs(oidx)
ils = _get_inner_inps(iidx + 1)
# Initially assume no outer input is connected to any outer output
connection_pattern = [[False for output in node.outputs]
for x in node.inputs]
# For every possible pair of outer input and outer output, iterate
# over every possible pairing of their corresponding inner inputs
# and inner outputs and, if one such pair of inner variables is
# connected, then the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)):
inner_oidxs = _get_inner_outs_idx(outer_oidx)
for outer_iidx in range(len(node.inputs)):
inner_iidxs = _get_inner_inps_idx(outer_iidx)
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
if inner_connect_pattern[inner_iidx][inner_oidx]:
connection_pattern[outer_iidx][outer_oidx] = True
break
if connection_pattern[outer_iidx][outer_oidx]:
break
if ils is None:
# The gradient should be disconnected
connection_pattern[iidx + 1][oidx] = False
else:
for inner_out in ols:
# We check for the dtype because inner_out could be
# any Theano type like Generic or RandomState, for
# which we can not impose a dtype
if hasattr(inner_out, 'dtype'):
# Note that we do not care about the output of
# this compute gradient. We just care to see if
# it is None or not. (i.e. disconnected or not)
try:
old = theano.config.compute_test_value
theano.config.compute_test_value = 'off'
tmp = compute_gradient(
inner_out,
safe_new(inner_out, dtype='float64'),
ils)
finally:
theano.config.compute_test_value = old
else:
# It should be undefined, not disconnected
tmp = ils
if any([x is not None for x in tmp]):
connection_pattern[iidx + 1][oidx] = True
# Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论