提交 2c24b282 作者: Frederic Bastien

Add comment and assert in scan code.

上级 758f0ff9
@@ -166,7 +166,13 @@ class Scan(PureOp):
'could happen if the inner graph of scan results in '
'an upcast or downcast. Please make sure that you use'
'dtypes consistently')
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
#assert len(inputs) >= len(self.inputs)
# if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"]
# else:
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self.vector_seqs = [seq.ndim == 1 for seq in
@@ -903,7 +909,12 @@ class Scan(PureOp):
# Here, we build a list inner_ins_shape, such that inner_ins_shape[i]
# is the shape of self.inputs[i]
for inp, inp_shp in zip(node.inputs, input_shapes):
assert inp_shp is None or len(inp_shp) == inp.ndim
# sequences
# We skip iputs_shapes[0] as it is the total or current number
# of iteration
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
# mit_mot, mit_sot, sit_sot
......
@@ -372,6 +372,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty env, otherwise we will
# need to do it manually
for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.ndim:
assert len(inp_shp) == inp.ndim
shape_feature = tensor.opt.ShapeFeature()
shape_feature.on_attach(theano.gof.Env([], []))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论