提交 6d661d80 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix to Scan.infer_shape after code review.

上级 2ffbd9fa
...@@ -602,7 +602,7 @@ class Scan(Op): ...@@ -602,7 +602,7 @@ class Scan(Op):
# Non-sequences have a direct equivalent from self.inputs in node.inputs # Non-sequences have a direct equivalent from self.inputs in node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):] inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = {} out_equivalent = {}
for in_ns, out_ns in zip(inner_non_sequences, input_shapes[offset:]): for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns out_equivalent[in_ns] = out_ns
outs_shape = scan_utils.infer_shape( outs_shape = scan_utils.infer_shape(
...@@ -629,9 +629,9 @@ class Scan(Op): ...@@ -629,9 +629,9 @@ class Scan(Op):
# node.inputs, and constants, without using the variables # node.inputs, and constants, without using the variables
# in the inner function. # in the inner function.
r = node.outputs[n_outs+x] r = node.outputs[n_outs+x]
assert r.ndim == 1 + len(outs_shape[n_outs+x]) assert r.ndim == 1 + len(out_shape_x)
shp = [node.inputs[offset+self.n_shared_outs+x]] shp = [node.inputs[offset+self.n_shared_outs+x]]
for i, shp_i in zip(xrange(1,r.ndim), outs_shape[n_outs+x]): for i, shp_i in zip(xrange(1,r.ndim), out_shape_x):
# Validate shp_i. v_shape_i is either None (if invalid), # Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates # or a (variable, Boolean) tuple. The Boolean indicates
# whether variable is shp_i (if True), or an valid # whether variable is shp_i (if True), or an valid
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论