提交 68e0fea1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Don't force variables to have a name field .. but rather first check if

there is one or not.
上级 d1c51b24
......@@ -371,7 +371,7 @@ def scan( fn
# ^ explicitly provided a None for taps
warning (' Output %s ( index %d) has a initial state '
' but taps is explicitly set to None ' % (
outs_info[i]['initial'].name
getattr(outs_info[i]['initial'],'name','None')
, i) )
outs_info[i]['taps'] = [-1]
else:
......@@ -416,12 +416,10 @@ def scan( fn
nw_slice = seq['input'][0].type()
actual_slice = seq['input'][k-mintap]
if not hasattr(seq['input'],'name'):
raise TypeError('Expected object with a "name" field, got '+str(seq)+"['input'] = "+str(seq['input']))
# Add names to slices for debugging and pretty printing ..
# that is if the input already has a name
if seq['input'].name:
if getattr(seq['input'],'name', None) is not None:
if k > 0:
nw_name = seq['input'].name + '[t+%d]'%k
elif k == 0:
......@@ -481,7 +479,7 @@ def scan( fn
# Add names -- it helps a lot when debugging
for (nw_seq, seq) in zip(scan_seqs, seqs):
if seq['input'].name:
if getattr(seq['input'],'name', None) is not None:
nw_seq.name = seq['input'].name + '[%d:]'%k
# Conventions :
......@@ -534,7 +532,7 @@ def scan( fn
actual_arg = init_out['initial']
arg = safe_new(init_out['initial'])
if init_out['initial'].name:
if getattr(init_out['initial'],'name', None) is not None:
arg.name = init_out['initial'].name+'[t-1]'
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
......@@ -579,7 +577,7 @@ def scan( fn
nw_slice = init_out['initial'][0].type()
# give it a name or debugging and pretty printing
if init_out['initial'].name:
if getattr(init_out['initial'],'name', None) is not None:
if k > 0:
nw_slice.name = ( init_out['initial'].name +
'[t+%d]'%k )
......@@ -746,7 +744,7 @@ def scan( fn
for input in dummy_f.maker.expanded_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if input.variable.name:
if getattr(input.variable,'name', None) is not None:
new_var.name = input.variable.name + '_copy'
shared_inner_inputs.append( new_var )
shared_scan_inputs.append( input.variable )
......@@ -777,7 +775,7 @@ def scan( fn
## Step 5.6 all shared variables with no update rules
def new_variable( v ):
new_v = safe_new(v)
if v.name:
if getattr(v,'name', None) is not None:
new_v.name = v.name + '_copy'
return new_v
other_inner_args += [ new_variable(arg) for arg in non_seqs
......
......@@ -226,10 +226,10 @@ class Scan(Op):
for idx in xrange(self.n_seqs):
if inputs[1+idx].dtype != self.inputs[idx].dtype:
raise ValueError(err_msg1%( 'Sequence'
, inputs[1+idx].name
, str(inputs[1+idx])
, idx
, inputs[1+idx].dtype
, self.inputs[idx].name
, str(self.inputs[idx])
, self.inputs[idx].dtype) )
# Check that this 3 things have the same dtype for mit_mot:
......@@ -246,10 +246,10 @@ class Scan(Op):
for k in self.tap_array[index-start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state'
, inputs[index].name
, str(inputs[index])
, index
, inputs[index].dtype
, self.inputs[index_i].name
, str(self.inputs[index_i])
, self.inputs[index_i].dtype) )
index_i += 1
for k in self.mit_mot_out_slices[index-start]:
......@@ -266,14 +266,14 @@ class Scan(Op):
for k in self.tap_array[index-start]:
if inputs[index].dtype != self.inputs[index_i].dtype:
raise ValueError(err_msg1%( 'Initial state'
, inputs[index].name
, str(inputs[index])
, index
, inputs[index].dtype
, self.inputs[index_i].name
, str(self.inputs[index_i])
, self.inputs[index_i].dtype) )
index_i += 1
if inputs[index].dtype != self.outputs[index_o].dtype:
raise ValueError(err_msg2%( inputs[index].name
raise ValueError(err_msg2%( str(inputs[index])
, index
, inputs[index].dtype
, self.outputs[index_o].dtype) )
......@@ -287,7 +287,7 @@ class Scan(Op):
while index < end:
if (hasattr(inputs[index],'dtype') and
inputs[index].dtype != self.outputs[index_o].dtype):
raise ValueError(err_msg2%( inputs[index].name
raise ValueError(err_msg2%( str(inputs[index])
, index
, inputs[index].dtype
, self.outputs[index_o].dtype) )
......@@ -794,7 +794,7 @@ class Scan(Op):
g_out_slices.append(g_outs_no_shared[dx][0])
else:
g_out_slices.append(None)
if out.name:
if getattr(out,'name',None) is not None:
inner_g_out.name = 'g_'+out.name
else:
inner_g_out.name = 'g_'+str(dx)
......@@ -872,7 +872,7 @@ class Scan(Op):
nw_seq = seq[dim_offset +k -mintap: -(maxtap -k)]
else:
nw_seq = seq[dim_offset +k -mintap: ]
if seq.name:
if getattr(seq,'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k
scan_seqs.append(nw_seq)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论