提交 cfdd63a6 authored 作者: carriepl's avatar carriepl

Make Scan allocate initial buffers with expand_empty

上级 68180cf9
...@@ -627,7 +627,7 @@ def scan(fn, ...@@ -627,7 +627,7 @@ def scan(fn,
# the initial state over. We do this using the expand function # the initial state over. We do this using the expand function
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand( scan_utils.expand_empty(
tensor.unbroadcast( tensor.unbroadcast(
tensor.shape_padleft(actual_arg), 0), tensor.shape_padleft(actual_arg), 0),
actual_n_steps actual_n_steps
...@@ -653,7 +653,7 @@ def scan(fn, ...@@ -653,7 +653,7 @@ def scan(fn,
idx_offset = abs(numpy.min(init_out['taps'])) idx_offset = abs(numpy.min(init_out['taps']))
# Sequence # Sequence
mit_sot_scan_inputs.append( mit_sot_scan_inputs.append(
scan_utils.expand(init_out['initial'][:mintap], scan_utils.expand_empty(init_out['initial'][:mintap],
actual_n_steps)) actual_n_steps))
if i in return_steps: if i in return_steps:
...@@ -866,7 +866,7 @@ def scan(fn, ...@@ -866,7 +866,7 @@ def scan(fn,
if isinstance(new_var.type, ops.expandable_types): if isinstance(new_var.type, ops.expandable_types):
sit_sot_inner_inputs.append(new_var) sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand( scan_utils.expand_empty(
tensor.unbroadcast( tensor.unbroadcast(
tensor.shape_padleft(input.variable), 0), tensor.shape_padleft(input.variable), 0),
actual_n_steps)) actual_n_steps))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论