提交 319c583a authored 作者: carriepl's avatar carriepl

Implement expand_empty in scan_utils.py

上级 9c6b0886
...@@ -624,6 +624,22 @@ def expand(tensor_var, size): ...@@ -624,6 +624,22 @@ def expand(tensor_var, size):
return tensor.set_subtensor(empty[:shapes[0]], tensor_var) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def expand_empty(tensor_var, size):
"""
Transforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding uninitialized memory at the end of the tensor.
"""
if size == 0:
return tensor_var
shapes = [tensor_var.shape[x] for x in xrange(tensor_var.ndim)]
new_shape = [size + shapes[0]] + shapes[1:]
empty = tensor.AllocEmpty(tensor_var.dtype)(*new_shape)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
def equal_computations(xs, ys, in_xs=None, in_ys=None): def equal_computations(xs, ys, in_xs=None, in_ys=None):
"""Checks if Theano graphs represent the same computations. """Checks if Theano graphs represent the same computations.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论