提交 9333a813 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

a few more changes to make the cython code work

上级 2acda2d0
...@@ -63,7 +63,7 @@ from theano.sandbox import cuda ...@@ -63,7 +63,7 @@ from theano.sandbox import cuda
def get_version(): def get_version():
return 0.266 return 0.276
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -85,7 +85,7 @@ def perform( ...@@ -85,7 +85,7 @@ def perform(
numpy.ndarray[numpy.int32_t,ndim=1] mit_mot_out_nslices, numpy.ndarray[numpy.int32_t,ndim=1] mit_mot_out_nslices,
fn, fn,
fnct, fnct,
bint inplace, numpy.ndarray[numpy.int32_t,ndim=1] destroy_map,
args, args,
outs, outs,
self): self):
...@@ -145,9 +145,8 @@ def perform( ...@@ -145,9 +145,8 @@ def perform(
fnct: python object fnct: python object
Only used to attach some timings for the profile mode ( can be Only used to attach some timings for the profile mode ( can be
skiped if we don't care about Theano's profile mode) skiped if we don't care about Theano's profile mode)
inplace destroy_map
Boolean that says if things should be computed inplace or if they Array of boolean saying if an output is computed inplace
should not.
args: list of ndarrays (and random states) args: list of ndarrays (and random states)
The inputs of scan in a given order ( n_steps, sequences, mit_mot, The inputs of scan in a given order ( n_steps, sequences, mit_mot,
mit_sot, sit_sot, nit_sot, shared_outs, other_args) mit_sot, sit_sot, nit_sot, shared_outs, other_args)
...@@ -230,7 +229,7 @@ def perform( ...@@ -230,7 +229,7 @@ def perform(
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in range(n_outs): for idx in range(n_outs):
if inplace: if destroy_map[idx] != 0:
# ^ Case 1. Outputs should be computed inplace of their # ^ Case 1. Outputs should be computed inplace of their
# initial state # initial state
outs[idx][0] = args[ <unsigned int>(1+ n_seqs + idx)] outs[idx][0] = args[ <unsigned int>(1+ n_seqs + idx)]
......
...@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.DEBUG) ...@@ -14,7 +14,7 @@ logging.basicConfig(level=logging.DEBUG)
if config.compiledir not in sys.path: if config.compiledir not in sys.path:
sys.path.append(config.compiledir) sys.path.append(config.compiledir)
version = 0.266 # must match constant returned in function get_version() version = 0.276 # must match constant returned in function get_version()
need_reload = False need_reload = False
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论