提交 9a70e813 authored 作者: carriepl's avatar carriepl

Fix bug and refactor in cython backend

上级 5e358424
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -62,7 +62,7 @@ import copy ...@@ -62,7 +62,7 @@ import copy
def get_version(): def get_version():
return 0.288 return 0.289
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -445,7 +445,7 @@ def perform( ...@@ -445,7 +445,7 @@ def perform(
# modified inplace and nothing needs to be done. # modified inplace and nothing needs to be done.
if not same_data: if not same_data:
outs[j][0][<unsigned int>(k + pos[j])] = \ outs[j][0][<unsigned int>(k + pos[j])] = \
input_storage[<unsigned int>inp_idx].storage[0] input_storage[<unsigned int>(n_seqs + inp_idx)].storage[0]
else: else:
# This output tap has not been preallocated, recover # This output tap has not been preallocated, recover
...@@ -465,6 +465,10 @@ def perform( ...@@ -465,6 +465,10 @@ def perform(
for j in range(begin, end): for j in range(begin, end):
# Copy the output value to `outs`, if necessary
if store_steps[j] == 1 or vector_outs[j] == 1:
outs[j][0][pos[j]] = output_storage[<unsigned int>(offset_out+j)].storage[0]
else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[offset_out + j]
...@@ -480,15 +484,33 @@ def perform( ...@@ -480,15 +484,33 @@ def perform(
else: else:
output_reused = False output_reused = False
# Copy the output value to `outs`, if necessary if not output_reused:
if store_steps[j] == 1 or vector_outs[j] == 1 or not output_reused: outs[j][0][pos[j]] = \
outs[j][0][pos[j]] = output_storage[<unsigned int>(offset_out+j)].storage[0] output_storage[<unsigned int>(offset_out+j)].storage[0]
# 5.5 Copy over the values for nit_sot outputs # 5.5 Copy over the values for nit_sot outputs
begin = end begin = end
end += n_nit_sot end += n_nit_sot
for j in range(begin,end): for j in range(begin,end):
if i == 0:
jout = j+offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0:
vector_outs[j] = 1
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ):
outs[j][0] = node.outputs[j].type.value_zeros(shape)
elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0]
elif store_steps[j] == 1 or vector_outs[j] == 1:
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
else:
# Check whether the initialization of the output storage map # Check whether the initialization of the output storage map
# for this output has been reused. # for this output has been reused.
old_var = old_output_storage[offset_out + j] old_var = old_output_storage[offset_out + j]
...@@ -504,22 +526,7 @@ def perform( ...@@ -504,22 +526,7 @@ def perform(
else: else:
output_reused = False output_reused = False
if i == 0: if not output_reused:
jout = j+offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0:
vector_outs[j] = 1
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
outs[j][0].dtype != dtype ):
outs[j][0] = node.outputs[j].type.value_zeros(shape)
elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = output_storage[jout].storage[0]
elif (store_steps[j] == 1 or vector_outs[j] == 1 or
not output_reused):
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0] outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
# 5.6 Copy over the values for outputs corresponding to shared # 5.6 Copy over the values for outputs corresponding to shared
......
...@@ -17,7 +17,7 @@ from theano.gof import cmodule ...@@ -17,7 +17,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform') _logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.288 # must match constant returned in function get_version() version = 0.289 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论