提交 9003aba4 authored 作者: Frederic's avatar Frederic

Do the vector workaround only on TensorVariable and document it.

上级 54c8ee2f
...@@ -422,9 +422,15 @@ class Scan(PureOp): ...@@ -422,9 +422,15 @@ class Scan(PureOp):
raise ValueError('For output %s you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot)) 'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs) assert len(new_inputs) == len(inputs)
self.vector_seqs = [seq.ndim == 1 for seq in
new_inputs[1:1 + self.n_seqs]] # The vector_seqs and vector_outs are just a workaround
self.vector_outs = [arg.ndim == 1 for arg in # strange NumPy behavior: vector_ndarray[int] return a NumPy
# scalar and not a NumPy ndarray of 0 dimensions.
self.vector_seqs = [isinstance(seq, tensor.TensorVariable) and
seq.ndim == 1 for seq in
new_inputs[1:1 + self.n_seqs]]
self.vector_outs = [isinstance(arg, tensor.TensorVariable) and
arg.ndim == 1 for arg in
new_inputs[1 + self.n_seqs: (1 + self.n_seqs + new_inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]] self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot self.vector_outs += [False] * self.n_nit_sot
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论