提交 8de65286 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

infer_shape function for scan

上级 c9a2c0dd
...@@ -149,6 +149,17 @@ class ScanOp(PureOp): ...@@ -149,6 +149,17 @@ class ScanOp(PureOp):
rval = rval ^ val rval = rval ^ val
return rval return rval
def infer_shape(self, node, input_shapes):
for inp, inp_shp in izip(node.inputs, input_shapes):
assert inp_shp is None or len(inp_shp) == inp.type.ndim
n_outs = len(self.outputs)
if self.gpu:
return [(Shape_i(0)(o),) + x[1:] for o, x
in izip(node.outputs, input_shapes[1: n_outs + 1])]
else:
return input_shapes[1: n_outs + 1]
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论