提交 b4b313d2 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Frederic Bastien

Fix gh-6578 Fix scan naming of variables.

上级 c28ece51
......@@ -536,11 +536,16 @@ def scan(fn,
maxtap_proxy = max(maxtap, 0)
mintap_proxy = min(mintap, 0)
start = (k - mintap_proxy)
nw_name = None
if k == maxtap_proxy:
nw_seq = seq['input'][start:]
if getattr(seq['input'], 'name', None) is not None:
nw_name = seq['input'].name + "[%d:]" % start
else:
end = -(maxtap_proxy - k)
nw_seq = seq['input'][start:end]
if getattr(seq['input'], 'name', None) is not None:
nw_name = seq['input'].name + "[%d:%d]" % (start, end)
if go_backwards:
nw_seq = nw_seq[::-1]
......@@ -549,6 +554,9 @@ def scan(fn,
inner_seqs.append(nw_slice)
inner_slices.append(actual_slice)
n_seqs += 1
# Add names -- it helps a lot when debugging
if nw_name is not None:
nw_seq.name = nw_name
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
......@@ -576,12 +584,6 @@ def scan(fn,
else:
actual_n_steps = tensor.as_tensor(n_steps)
# Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs):
if getattr(seq['input'], 'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]' % k
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
......
......@@ -1604,44 +1604,44 @@ class T_Scan(unittest.TestCase):
| | | | | | |Elemwise{minimum,no_inplace} [id K] ''
| | | | | | | |Subtensor{int64} [id L] ''
| | | | | | | | |Shape [id M] ''
| | | | | | | | | |Subtensor{int64::} [id N] 'u1[1:]'
| | | | | | | | | |Subtensor{int64::} [id N] 'u1[0:]'
| | | | | | | | | |u1 [id O]
| | | | | | | | | |Constant{0} [id P]
| | | | | | | | |Constant{0} [id Q]
| | | | | | | |Subtensor{int64} [id R] ''
| | | | | | | |Shape [id S] ''
| | | | | | | | |Subtensor{int64:int64:} [id T] 'u2[1:]'
| | | | | | | | |Subtensor{int64:int64:} [id T] 'u2[0:-2]'
| | | | | | | | |u2 [id U]
| | | | | | | | |Constant{0} [id V]
| | | | | | | | |Constant{-2} [id W]
| | | | | | | |Constant{0} [id X]
| | | | | | |Subtensor{int64} [id Y] ''
| | | | | | |Shape [id Z] ''
| | | | | | | |Subtensor{int64:int64:} [id BA] ''
| | | | | | | |Subtensor{int64:int64:} [id BA] 'u2[1:-1]'
| | | | | | | |u2 [id U]
| | | | | | | |Constant{1} [id BB]
| | | | | | | |Constant{-1} [id BC]
| | | | | | |Constant{0} [id BD]
| | | | | |Subtensor{int64} [id BE] ''
| | | | | |Shape [id BF] ''
| | | | | | |Subtensor{int64::} [id BG] ''
| | | | | | |Subtensor{int64::} [id BG] 'u2[2:]'
| | | | | | |u2 [id U]
| | | | | | |Constant{2} [id BH]
| | | | | |Constant{0} [id BI]
| | | | |Subtensor{:int64:} [id BJ] ''
| | | | | |Subtensor{int64::} [id N] 'u1[1:]'
| | | | | |Subtensor{int64::} [id N] 'u1[0:]'
| | | | | |ScalarFromTensor [id BK] ''
| | | | | |Elemwise{minimum,no_inplace} [id I] ''
| | | | |Subtensor{:int64:} [id BL] ''
| | | | | |Subtensor{int64:int64:} [id T] 'u2[1:]'
| | | | | |Subtensor{int64:int64:} [id T] 'u2[0:-2]'
| | | | | |ScalarFromTensor [id BM] ''
| | | | | |Elemwise{minimum,no_inplace} [id I] ''
| | | | |Subtensor{:int64:} [id BN] ''
| | | | | |Subtensor{int64:int64:} [id BA] ''
| | | | | |Subtensor{int64:int64:} [id BA] 'u2[1:-1]'
| | | | | |ScalarFromTensor [id BO] ''
| | | | | |Elemwise{minimum,no_inplace} [id I] ''
| | | | |Subtensor{:int64:} [id BP] ''
| | | | | |Subtensor{int64::} [id BG] ''
| | | | | |Subtensor{int64::} [id BG] 'u2[2:]'
| | | | | |ScalarFromTensor [id BQ] ''
| | | | | |Elemwise{minimum,no_inplace} [id I] ''
| | | | |IncSubtensor{Set;:int64:} [id BR] ''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论