提交 caef102d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Adding the cython version of scan

While we agreed that there might be a more principled way of solving this, this solution was fast to add and is pretty efficient for now.
上级 95e010f5
...@@ -357,6 +357,10 @@ class Scan(Op): ...@@ -357,6 +357,10 @@ class Scan(Op):
the thunk can potentially cache return values (like CLinker does), the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list. then it must not do so for variables in the no_recycling list.
""" """
# Setting up all my variables in what I believe is a more Cython
# friendly form
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs] node_input_compute = [compute_map[r] for r in node.inputs]
...@@ -384,7 +388,73 @@ class Scan(Op): ...@@ -384,7 +388,73 @@ class Scan(Op):
name = self.name, name = self.name,
profile = profile) profile = profile)
p = self.execute
try:
cython_mintaps = numpy.asarray(self.mintaps, dtype = 'int32')
cython_tap_array_len = \
numpy.asarray([ len(x) for x in self.tap_array],
dtype='int32')
if len(self.tap_array) == 0:
d1 = 0
else:
d1 = numpy.max(cython_tap_array_len)
d0 = len(self.tap_array)
cython_tap_array = numpy.zeros((d0,d1), dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_tap_array_len[_d0]):
cython_tap_array[_d0,_d1] = self.tap_array[_d0][_d1]
cython_mit_mot_out_nslices = \
numpy.asarray([ len(x) for x in self.mit_mot_out_slices],
dtype='int32')
if len(self.mit_mot_out_slices) == 0:
d1 = 0
else:
d1 = numpy.max(cython_mit_mot_out_nslices)
d0 = len(self.mit_mot_out_slices)
cython_mit_mot_out_slices = numpy.zeros((d0,d1),
dtype='int32')
for _d0 in range(d0):
for _d1 in range(cython_mit_mot_out_nslices[_d0]):
cython_mit_mot_out_slices[_d0,_d1] = \
self.mit_mot_out_slices[_d0][_d1]
vector_seqs = [ seq.ndim == 1 for seq in
self.inputs[1:1+self.n_seqs ] ]
vector_outs = [ arg.ndim ==1 for arg in
self.inputs[1+self.n_seqs: (1+self.n_seqs +
self.n_outs)] ]
vector_outs += [ False]*self.n_nit_sot
cython_vector_seqs = numpy.asarray(self.vector_seqs,
dtype='int32')
cython_vector_outs = numpy.asarray(self.vector_outs,
dtype='int32')
import scan_perform_ext
p = lambda node, args, outs:\
scan_perform_ext.perform(
self.n_shared_outs,
self.n_mit_mot_outs,
self.n_seqs,
self.n_mit_mot,
self.n_mit_sot,
self.n_sit_sot,
self.n_nit_sot,
args[0],
self.as_while,
cython_mintaps,
cython_tap_array,
cython_tap_array_len,
cython_vector_seqs,
cython_vector_outs,
cython_mit_mot_out_slices,
cython_mit_mot_out_nslices,
self.fn.fn,
self.fn,
self.inplace,
args,
outs,
self)
except ImportError:
p = self.execute
# default arguments are stored in the closure of `rval` # default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o) r = p(n, [x[0] for x in i], o)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论