提交 3fbd990b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Change the join(x, alloc(..)) with set_subtensor( alloc(..)[..], x)

上级 b0a7305b
...@@ -380,8 +380,7 @@ class ScanSaveMem(Optimizer): ...@@ -380,8 +380,7 @@ class ScanSaveMem(Optimizer):
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input =nw_inputs[offset+idx].owner.inputs[1] _nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] ) nw_input = scan_utils.expand( _nw_input, val - init_l[i] )
nw_inputs[offset+idx] = nw_input nw_inputs[offset+idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx) replaced_outs.append(op.n_mit_mot + idx)
......
...@@ -533,7 +533,8 @@ def scan( fn ...@@ -533,7 +533,8 @@ def scan( fn
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand( scan_utils.expand(
tensor.shape_padleft(actual_arg) tensor.unbroadcast(
tensor.shape_padleft(actual_arg), 0)
, actual_n_steps , actual_n_steps
) ) ) )
......
...@@ -509,11 +509,18 @@ def expand( tensor_var, size): ...@@ -509,11 +509,18 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization # Corner case that I might use in an optimization
if size == 0: if size == 0:
return tensor_var return tensor_var
#shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
#zeros_shape = [size] + shapes[1:]
#empty = tensor.zeros( zeros_shape
# , dtype = tensor_var.dtype)
#return tensor.join(0, tensor_var, empty)
# V2:
shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ] shapes = [ tensor_var.shape[x] for x in xrange(tensor_var.ndim) ]
zeros_shape = [size] + shapes[1:] zeros_shape = [size+shapes[0]] + shapes[1:]
empty = tensor.zeros( zeros_shape empty = tensor.zeros( zeros_shape
, dtype = tensor_var.dtype) , dtype = tensor_var.dtype)
return tensor.join(0, tensor_var, empty) return tensor.set_subtensor(empty[:shapes[0]], tensor_var)
class Clone(Op): class Clone(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论