提交 8f5ed94e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

consider corner case regarding initial state

initial state can be one dimension less if the taps equal just [-1]
上级 63ae9bb3
...@@ -66,17 +66,16 @@ class TestScan(unittest.TestCase): ...@@ -66,17 +66,16 @@ class TestScan(unittest.TestCase):
scan_inputs.append(dict(input=inp, taps=[x['tap'] for x in scan_inputs.append(dict(input=inp, taps=[x['tap'] for x in
info])) info]))
n_states = len(states_info) n_states = len(states_info)
states = [tensor.matrix('x%d' % k) for k in xrange(n_states)]
scan_states = [] scan_states = []
states = [] states = []
for state, info in zip(states, states_info): for info in states_info:
if len(info) == 1 and info[0]['tap'] == -1: if len(info) == 1 and info[0]['tap'] == -1:
state = tensor.vector('x%d' % k) state = tensor.vector('x%d' % k)
states.append(state) states.append(state)
scan_states.append(state) scan_states.append(state)
else: else:
state = tensor.matrix('x%d' % k) state = tensor.matrix('x%d' % k)
states.append(states) states.append(state)
scan_states.append( scan_states.append(
dict(initial=state, taps=[x['tap'] for x in info])) dict(initial=state, taps=[x['tap'] for x in info]))
n_parameters = len(parameters_info) n_parameters = len(parameters_info)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论