提交 201e95a7 authored 作者: nouiz's avatar nouiz

Merge pull request #249 from pascanur/infershape_cond_scan

Infershape cond scan
...@@ -979,6 +979,11 @@ class Scan(PureOp): ...@@ -979,6 +979,11 @@ class Scan(PureOp):
scan_outs += [x for x in scan_outs += [x for x in
input_shapes[offset:offset + self.n_shared_outs]] input_shapes[offset:offset + self.n_shared_outs]]
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if self.as_while:
scan_outs = [(Shape_i(0)(o),)+x[1:]
for o, x in zip(node.outputs,scan_outs)]
return scan_outs return scan_outs
### GRAD FUNCTION ### GRAD FUNCTION
......
...@@ -2495,8 +2495,6 @@ class T_Scan(unittest.TestCase): ...@@ -2495,8 +2495,6 @@ class T_Scan(unittest.TestCase):
# so if it compiles it means the test pass # so if it compiles it means the test pass
f = theano.function([V, W], O) f = theano.function([V, W], O)
def test_while2(self): def test_while2(self):
x = tensor.vector('x') x = tensor.vector('x')
def lambda_fn(x_t): def lambda_fn(x_t):
...@@ -2515,6 +2513,17 @@ class T_Scan(unittest.TestCase): ...@@ -2515,6 +2513,17 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 1 assert len(lssc) == 1
def test_while_infershape(self):
x = tensor.vector('x')
def lambda_fn(x_t):
return x_t + 1, theano.scan_module.until( x_t > 3)
o, _ = theano.scan(lambda_fn, x)
f = theano.function([x], o.shape[0], mode=mode_with_opt)
vx = numpy.zeros((50,), dtype = theano.config.floatX)
vx[23] = 4
out = f(vx)
assert out == 24
def test_grad_multiple_seqs_different_nsteps(self): def test_grad_multiple_seqs_different_nsteps(self):
# Example provided Michael Forbes # Example provided Michael Forbes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论