提交 c7dbe3db authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

...@@ -371,7 +371,7 @@ def scan( fn ...@@ -371,7 +371,7 @@ def scan( fn
# ^ explicitly provided a None for taps # ^ explicitly provided a None for taps
warning (' Output %s ( index %d) has a initial state ' warning (' Output %s ( index %d) has a initial state '
' but taps is explicitly set to None ' % ( ' but taps is explicitly set to None ' % (
outs_info[i]['initial'].name getattr(outs_info[i]['initial'],'name','None')
, i) ) , i) )
outs_info[i]['taps'] = [-1] outs_info[i]['taps'] = [-1]
else: else:
...@@ -416,12 +416,10 @@ def scan( fn ...@@ -416,12 +416,10 @@ def scan( fn
nw_slice = seq['input'][0].type() nw_slice = seq['input'][0].type()
actual_slice = seq['input'][k-mintap] actual_slice = seq['input'][k-mintap]
if not hasattr(seq['input'],'name'):
raise TypeError('Expected object with a "name" field, got '+str(seq)+"['input'] = "+str(seq['input']))
# Add names to slices for debugging and pretty printing .. # Add names to slices for debugging and pretty printing ..
# that is if the input already has a name # that is if the input already has a name
if seq['input'].name: if getattr(seq['input'],'name', None) is not None:
if k > 0: if k > 0:
nw_name = seq['input'].name + '[t+%d]'%k nw_name = seq['input'].name + '[t+%d]'%k
elif k == 0: elif k == 0:
...@@ -481,7 +479,7 @@ def scan( fn ...@@ -481,7 +479,7 @@ def scan( fn
# Add names -- it helps a lot when debugging # Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs): for (nw_seq, seq) in zip(scan_seqs, seqs):
if seq['input'].name: if getattr(seq['input'],'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]'%k nw_seq.name = seq['input'].name + '[%d:]'%k
# Conventions : # Conventions :
...@@ -534,7 +532,7 @@ def scan( fn ...@@ -534,7 +532,7 @@ def scan( fn
actual_arg = init_out['initial'] actual_arg = init_out['initial']
arg = safe_new(init_out['initial']) arg = safe_new(init_out['initial'])
if init_out['initial'].name: if getattr(init_out['initial'],'name', None) is not None:
arg.name = init_out['initial'].name+'[t-1]' arg.name = init_out['initial'].name+'[t-1]'
# We need now to allocate space for storing the output and copy # We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function # the initial state over. We do this using the expand function
...@@ -579,7 +577,7 @@ def scan( fn ...@@ -579,7 +577,7 @@ def scan( fn
nw_slice = init_out['initial'][0].type() nw_slice = init_out['initial'][0].type()
# give it a name or debugging and pretty printing # give it a name or debugging and pretty printing
if init_out['initial'].name: if getattr(init_out['initial'],'name', None) is not None:
if k > 0: if k > 0:
nw_slice.name = ( init_out['initial'].name + nw_slice.name = ( init_out['initial'].name +
'[t+%d]'%k ) '[t+%d]'%k )
...@@ -746,7 +744,7 @@ def scan( fn ...@@ -746,7 +744,7 @@ def scan( fn
for input in dummy_f.maker.expanded_inputs: for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update: if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if input.variable.name: if getattr(input.variable,'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append( new_var ) shared_inner_inputs.append( new_var )
shared_scan_inputs.append( input.variable ) shared_scan_inputs.append( input.variable )
...@@ -777,7 +775,7 @@ def scan( fn ...@@ -777,7 +775,7 @@ def scan( fn
## Step 5.6 all shared variables with no update rules ## Step 5.6 all shared variables with no update rules
def new_variable( v ): def new_variable( v ):
new_v = safe_new(v) new_v = safe_new(v)
if v.name: if getattr(v,'name', None) is not None:
new_v.name = v.name + '_copy' new_v.name = v.name + '_copy'
return new_v return new_v
other_inner_args += [ new_variable(arg) for arg in non_seqs other_inner_args += [ new_variable(arg) for arg in non_seqs
......
...@@ -226,10 +226,10 @@ class Scan(Op): ...@@ -226,10 +226,10 @@ class Scan(Op):
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
if inputs[1+idx].dtype != self.inputs[idx].dtype: if inputs[1+idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1%( 'Sequence' raise ValueError(err_msg1%( 'Sequence'
, inputs[1+idx].name , str(inputs[1+idx])
, idx , idx
, inputs[1+idx].dtype , inputs[1+idx].dtype
, self.inputs[idx].name , str(self.inputs[idx])
, self.inputs[idx].dtype) ) , self.inputs[idx].dtype) )
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
...@@ -246,10 +246,10 @@ class Scan(Op): ...@@ -246,10 +246,10 @@ class Scan(Op):
for k in self.tap_array[index-start]: for k in self.tap_array[index-start]:
if inputs[index].dtype != self.inputs[index_i].dtype: if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state' raise ValueError(err_msg1%( 'Initial state'
, inputs[index].name , str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
, self.inputs[index_i].name , str(self.inputs[index_i])
, self.inputs[index_i].dtype) ) , self.inputs[index_i].dtype) )
index_i += 1 index_i += 1
for k in self.mit_mot_out_slices[index-start]: for k in self.mit_mot_out_slices[index-start]:
...@@ -266,14 +266,14 @@ class Scan(Op): ...@@ -266,14 +266,14 @@ class Scan(Op):
for k in self.tap_array[index-start]: for k in self.tap_array[index-start]:
if inputs[index].dtype != self.inputs[index_i].dtype: if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state' raise ValueError(err_msg1%( 'Initial state'
, inputs[index].name , str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
, self.inputs[index_i].name , str(self.inputs[index_i])
, self.inputs[index_i].dtype) ) , self.inputs[index_i].dtype) )
index_i += 1 index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype: if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( inputs[index].name raise ValueError(err_msg2%( str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
, self.outputs[index_o].dtype) ) , self.outputs[index_o].dtype) )
...@@ -287,7 +287,7 @@ class Scan(Op): ...@@ -287,7 +287,7 @@ class Scan(Op):
while index < end: while index < end:
if (hasattr(inputs[index],'dtype') and if (hasattr(inputs[index],'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype): inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2%( inputs[index].name raise ValueError(err_msg2%( str(inputs[index])
, index , index
, inputs[index].dtype , inputs[index].dtype
, self.outputs[index_o].dtype) ) , self.outputs[index_o].dtype) )
...@@ -794,7 +794,7 @@ class Scan(Op): ...@@ -794,7 +794,7 @@ class Scan(Op):
g_out_slices.append(g_outs_no_shared[dx][0]) g_out_slices.append(g_outs_no_shared[dx][0])
else: else:
g_out_slices.append(None) g_out_slices.append(None)
if out.name: if getattr(out,'name',None) is not None:
inner_g_out.name = 'g_'+out.name inner_g_out.name = 'g_'+out.name
else: else:
inner_g_out.name = 'g_'+str(dx) inner_g_out.name = 'g_'+str(dx)
...@@ -872,7 +872,7 @@ class Scan(Op): ...@@ -872,7 +872,7 @@ class Scan(Op):
nw_seq = seq[dim_offset +k -mintap: -(maxtap -k)] nw_seq = seq[dim_offset +k -mintap: -(maxtap -k)]
else: else:
nw_seq = seq[dim_offset +k -mintap: ] nw_seq = seq[dim_offset +k -mintap: ]
if seq.name: if getattr(seq,'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k nw_seq.name = seq.name + '[%d:]'%k
scan_seqs.append(nw_seq) scan_seqs.append(nw_seq)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论