提交 3fbd990b 作者: Razvan Pascanu

Change the join(x, alloc(..)) with set_subtensor( alloc(..)[..], x)

上级 b0a7305b
......@@ -380,8 +380,7 @@ class ScanSaveMem(Optimizer):
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input =nw_inputs[offset+idx].owner.inputs[1]
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] )
nw_inputs[offset+idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx)
......
......@@ -533,7 +533,8 @@ def scan( fn
# defined in scan utils
sit_sot_scan_inputs.append(
scan_utils.expand(
tensor.shape_padleft(actual_arg)
tensor.unbroadcast(
tensor.shape_padleft(actual_arg), 0)
, actual_n_steps
) )
......
......@@ -509,11 +509,18 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization
if size == 0:
return tensor_var
#shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
#zeros_shape = [size] + shapes[1:]
#empty = tensor.zeros( zeros_shape
# , dtype = tensor_var.dtype)
#return tensor.join(0, tensor_var, empty)
# V2:
shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
zeros_shape = [size] + shapes[1:]
zeros_shape = [size+shapes[0]] + shapes[1:]
empty = tensor.zeros( zeros_shape
, dtype = tensor_var.dtype)
return tensor.join(0, tensor_var, empty)
return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
class Clone(Op):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论