提交 9a3fb648 authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Fix shape concatenation in scan

上级 048d0c47
...@@ -2561,7 +2561,7 @@ class Scan(PureOp): ...@@ -2561,7 +2561,7 @@ class Scan(PureOp):
n_zeros = inputs[0] - n_steps n_zeros = inputs[0] - n_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
shp = shp + x.shape[1:] shp = shp + tuple(x.shape[i] for i in range(1, x.ndim))
z = tensor.zeros(shp, dtype=x.dtype) z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0) x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x) gradients.append(x)
...@@ -2589,7 +2589,7 @@ class Scan(PureOp): ...@@ -2589,7 +2589,7 @@ class Scan(PureOp):
n_zeros = inputs[0] - grad_steps n_zeros = inputs[0] - grad_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
shp = shp + x.shape[1:] shp = shp + tuple(x.shape[i] for i in range(1, x.ndim))
z = tensor.zeros(shp, dtype=x.dtype) z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0) x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x) gradients.append(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论