提交 e17c495f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Removed depricated refrences to other_ignore arguments (they were not used

by the code .. and just reduced the readability of the code ) - all scan tests still pass
上级 4fc3c1b2
...@@ -646,7 +646,7 @@ if cuda.cuda_available: ...@@ -646,7 +646,7 @@ if cuda.cuda_available:
+ thescan.n_shared_outs) + thescan.n_shared_outs)
nw_ins += [safe_to_gpu(x) for x in inputs[1:e] ] nw_ins += [safe_to_gpu(x) for x in inputs[1:e] ]
b = e b = e
e = e + thescan.n_nit_sot + thescan.n_other_ignore e = e + thescan.n_nit_sot
nw_ins += inputs[b:e] nw_ins += inputs[b:e]
nw_ins += [safe_to_gpu(x) for x in inputs[e:] ] nw_ins += [safe_to_gpu(x) for x in inputs[e:] ]
scan_ins = [ tensor_to_cuda(x) for x in thescan.inputs] scan_ins = [ tensor_to_cuda(x) for x in thescan.inputs]
...@@ -678,7 +678,7 @@ if cuda.cuda_available: ...@@ -678,7 +678,7 @@ if cuda.cuda_available:
+ thescan.n_shared_outs) + thescan.n_shared_outs)
nw_ins += [safe_to_gpu(x) for x in inputs[1:e] ] nw_ins += [safe_to_gpu(x) for x in inputs[1:e] ]
b = e b = e
e = e + thescan.n_nit_sot + thescan.n_other_ignore e = e + thescan.n_nit_sot
nw_ins += inputs[b:e] nw_ins += inputs[b:e]
nw_ins += [safe_to_gpu(x) for x in inputs[e:] ] nw_ins += [safe_to_gpu(x) for x in inputs[e:] ]
......
...@@ -759,23 +759,14 @@ def scan( fn ...@@ -759,23 +759,14 @@ def scan( fn
nit_sot_rightOrder.append( i ) nit_sot_rightOrder.append( i )
n_nit_sot += 1 n_nit_sot += 1
## Step 5.5 Sequences with no taps used ## Step 5.5 all other arguments including extra inputs
n_other_ignore = 0
ignore_scan_seqs = []
ignore_inner_seqs = []
for i,seq in enumerate(seqs):
if not 'taps' in seq:
ignore_scan_seqs.append(seq['input'])
n_other_ignore += 1
## Step 5.6 all other arguments including extra inputs
other_scan_args = [] other_scan_args = []
other_inner_args = [] other_inner_args = []
other_scan_args += [ arg for arg in non_seqs other_scan_args += [ arg for arg in non_seqs
if not isinstance(arg, SharedVariable) ] if not isinstance(arg, SharedVariable) ]
## Step 5.8 all shared variables with no update rules ## Step 5.6 all shared variables with no update rules
def new_variable( v ): def new_variable( v ):
new_v = safe_new(v) new_v = safe_new(v)
if v.name: if v.name:
...@@ -805,7 +796,6 @@ def scan( fn ...@@ -805,7 +796,6 @@ def scan( fn
mit_sot_inner_inputs + mit_sot_inner_inputs +
sit_sot_inner_inputs + sit_sot_inner_inputs +
shared_inner_inputs + shared_inner_inputs +
ignore_inner_seqs +
other_shared_inner_args + other_shared_inner_args +
other_inner_args ) other_inner_args )
...@@ -850,7 +840,6 @@ def scan( fn ...@@ -850,7 +840,6 @@ def scan( fn
info['n_sit_sot'] = n_sit_sot info['n_sit_sot'] = n_sit_sot
info['n_shared_outs'] = n_shared_outs info['n_shared_outs'] = n_shared_outs
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['n_other_ignore'] = n_other_ignore
info['truncate_gradient'] = truncate_gradient info['truncate_gradient'] = truncate_gradient
info['name'] = name info['name'] = name
info['mode'] = mode info['mode'] = mode
...@@ -876,7 +865,6 @@ def scan( fn ...@@ -876,7 +865,6 @@ def scan( fn
sit_sot_scan_inputs + sit_sot_scan_inputs +
shared_scan_inputs + shared_scan_inputs +
[ actual_n_steps for x in xrange(n_nit_sot) ] + [ actual_n_steps for x in xrange(n_nit_sot) ] +
ignore_scan_seqs +
other_shared_scan_args + other_shared_scan_args +
other_scan_args ) other_scan_args )
......
...@@ -427,7 +427,7 @@ class Scan(Op): ...@@ -427,7 +427,7 @@ class Scan(Op):
outs[idx][0] = args[self.seqs_arg_offset + idx].copy() outs[idx][0] = args[self.seqs_arg_offset + idx].copy()
offset = self.nit_sot_arg_offset + self.n_nit_sot + self.n_other_ignore offset = self.nit_sot_arg_offset + self.n_nit_sot
other_args = args[offset:] other_args = args[offset:]
zipped_outs = [(outs[idx], self.vector_outs[idx], tap, zipped_outs = [(outs[idx], self.vector_outs[idx], tap,
store_steps[idx], idx) for idx in xrange(self.n_outs) store_steps[idx], idx) for idx in xrange(self.n_outs)
...@@ -595,7 +595,7 @@ class Scan(Op): ...@@ -595,7 +595,7 @@ class Scan(Op):
outs_shape += [ input_shapes[idx+offset] ] outs_shape += [ input_shapes[idx+offset] ]
# non_sequences # non_sequences
offset += self.n_nit_sot + self.n_other_ignore + self.n_shared_outs offset += self.n_nit_sot + self.n_shared_outs
inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:] inner_ins_shapes = seqs_shape + outs_shape + input_shapes[offset:]
assert len(inner_ins_shapes) == len(self.inputs) assert len(inner_ins_shapes) == len(self.inputs)
...@@ -940,10 +940,8 @@ class Scan(Op): ...@@ -940,10 +940,8 @@ class Scan(Op):
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
info['inplace'] = False info['inplace'] = False
info['n_other_ignore'] = 0
n_mit_sot = 0 n_mit_sot = 0
n_sit_sot = 0 n_sit_sot = 0
n_other_ignore_seqs = 0
if self.truncate_gradient != -1 : if self.truncate_gradient != -1 :
do_steps = tensor.minimum(args[0], self.truncate_gradient) do_steps = tensor.minimum(args[0], self.truncate_gradient)
else: else:
...@@ -955,8 +953,7 @@ class Scan(Op): ...@@ -955,8 +953,7 @@ class Scan(Op):
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_shared_outs + self.n_shared_outs )
+ self.n_other_ignore )
scan_inputs = ( [do_steps] + scan_inputs = ( [do_steps] +
scan_seqs + scan_seqs +
...@@ -999,7 +996,6 @@ class Scan(Op): ...@@ -999,7 +996,6 @@ class Scan(Op):
gradients += [ x[::-1] for x in outputs[:end]] gradients += [ x[::-1] for x in outputs[:end]]
gradients += [ None for x in xrange(self.n_shared_outs)] gradients += [ None for x in xrange(self.n_shared_outs)]
gradients += [ None for x in xrange(self.n_nit_sot) ] gradients += [ None for x in xrange(self.n_nit_sot) ]
gradients += [ None for x in xrange(self.n_other_ignore) ]
begin = end + self.n_seqs begin = end + self.n_seqs
end = begin + n_shared_outs end = begin + n_shared_outs
......
...@@ -754,7 +754,6 @@ def compress_outs(op, not_required, inputs): ...@@ -754,7 +754,6 @@ def compress_outs(op, not_required, inputs):
info['n_sit_sot'] = 0 info['n_sit_sot'] = 0
info['n_shared_outs'] = 0 info['n_shared_outs'] = 0
info['n_nit_sot'] = 0 info['n_nit_sot'] = 0
info['n_other_ignore'] = op.info['n_other_ignore']
info['truncate_gradient'] = op.info['truncate_gradient'] info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name'] info['name'] = op.info['name']
info['inplace'] = op.info['inplace'] info['inplace'] = op.info['inplace']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论