提交 caef102d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Adding the cython version of scan

While we agreed that there might be a more principial way of solving this, this solution was fast to add and it is pretty efficient for now.
上级 95e010f5
......@@ -357,6 +357,10 @@ class Scan(Op):
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
......@@ -384,7 +388,73 @@ class Scan(Op):
name = self.name,
profile = profile)
p = self.execute
try:
cython_mintaps = numpy.asarray(self.mintaps, dtype = 'int32')
cython_tap_array_len = \
numpy.asarray([ len(x) for x in self.tap_array],
dtype='int32')
if len(self.tap_array) == 0:
d1 = 0
else:
d1 = numpy.max(cython_tap_array_len)
d0 = len(self.tap_array)
cython_tap_array = numpy.zeros((d0,d1), dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0,_d1] = self.tap_array[_d0][_d1]
cython_mit_mot_out_nslices = \
numpy.asarray([ len(x) for x in self.mit_mot_out_slices],
dtype='int32')
if len(self.mit_mot_out_slices) == 0:
d1 = 0
else:
d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0,d1),
dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0,_d1] = \
self.mit_mot_out_slices[_d0][_d1]
vector_seqs = [ seq.ndim == 1 for seq in
self.inputs[1:1+self.n_seqs ] ]
vector_outs = [ arg.ndim ==1 for arg in
self.inputs[1+self.n_seqs: (1+self.n_seqs +
self.n_outs)] ]
vector_outs += [ False]*self.n_nit_sot
cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32')
cython_vector_outs = numpy.asarray(self.vector_outs,
dtype='int32')
import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
self.fn.fn,
self.fn,
self.inplace,
args,
outs,
self)
except ImportError:
p = self.execute
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论