提交 78d0511e 作者: --global

Avoid duplicate input/output taps in mitmots

上级 c042a9c4
...@@ -2016,6 +2016,7 @@ class Scan(PureOp): ...@@ -2016,6 +2016,7 @@ class Scan(PureOp):
undefined_msg = None undefined_msg = None
through_shared = False through_shared = False
disconnected = True disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx]) mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
...@@ -2023,7 +2024,13 @@ class Scan(PureOp): ...@@ -2023,7 +2024,13 @@ class Scan(PureOp):
out_pos += 1 out_pos += 1
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
tap = -self.tap_array[idx][jdx]
# Only create a new inner input if there is not already one
# associated with this input tap
if tap not in mitmot_inp_taps[idx]:
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
if isinstance(dC_dinps_t[ins_pos].type, NullType): if isinstance(dC_dinps_t[ins_pos].type, NullType):
# We cannot use Null in the inner graph, so we # We cannot use Null in the inner graph, so we
# use a zero tensor of the appropriate shape instead. # use a zero tensor of the appropriate shape instead.
...@@ -2032,7 +2039,24 @@ class Scan(PureOp): ...@@ -2032,7 +2039,24 @@ class Scan(PureOp):
dtype=theano.config.floatX)) dtype=theano.config.floatX))
undefined_msg = dC_dinps_t[ins_pos].type.why_null undefined_msg = dC_dinps_t[ins_pos].type.why_null
else: else:
inner_out_mitmot.append(dC_dinps_t[ins_pos]) new_inner_out_mitmot = dC_dinps_t[ins_pos]
# If there is already an inner input associated with that
# input tap, make sure the computation of the new output
# uses it instead of the input it's currently using
if tap in mitmot_inp_taps[idx]:
print("Duplicate tap")
to_replace = dC_dXtm1s[ins_pos - self.n_seqs]
replacement_idx = (len(mitmot_inp_taps[idx]) -
mitmot_inp_taps[idx].index(tap))
replacement = inner_inp_mitmot[-replacement_idx]
self.tap_array[idx]
new_inner_out_mitmot = theano.clone(new_inner_out_mitmot,
replace=[(to_replace, replacement)])
inner_out_mitmot.append(new_inner_out_mitmot)
if not disconnected_dC_dinps_t[ins_pos]: if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False disconnected = False
...@@ -2041,12 +2065,15 @@ class Scan(PureOp): ...@@ -2041,12 +2065,15 @@ class Scan(PureOp):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]): if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
through_shared = True through_shared = True
n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx]) mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
# Only add the tap as a new input tap if needed
if tap not in mitmot_inp_taps[idx]:
n_mitmot_inps += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
if undefined_msg: if undefined_msg:
type_outs.append(undefined_msg) type_outs.append(undefined_msg)
elif through_shared: elif through_shared:
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论