提交 6349790c authored 作者: Frederic Bastien

Make ScipyGer use prepare_node

上级 ca40ef22
...@@ -2172,8 +2172,8 @@ class GpuConv(GpuOp): ...@@ -2172,8 +2172,8 @@ class GpuConv(GpuOp):
bmode = 0 bmode = 0
if max_threads_dim0 is None: if max_threads_dim0 is None:
raise NotImplementedError("GpuConv.c_code should not be called " raise NotImplementedError("GpuConv.c_code should not be called "
"directly. It should be called by " "directly. It should be called after "
"make_thunk() that add some information " "prepare_node() that add some information "
"related to the selected GPU.") "related to the selected GPU.")
sub.update(locals()) sub.update(locals())
return """ return """
......
...@@ -22,22 +22,18 @@ if have_fblas: ...@@ -22,22 +22,18 @@ if have_fblas:
class ScipyGer(Ger): class ScipyGer(Ger):
# keep everything else, but override the make_thunk def prepare_node(self, node, storage_map, compute_map):
def make_thunk(self, node, storage_map, compute_map, no_recycling): if impl == 'py':
node.tag.local_ger = _blas_ger_fns[numpy.dtype(
node_input_storage = [storage_map[r] for r in node.inputs] node.inputs[0].type.dtype)]
node_output_storage = [storage_map[r] for r in node.outputs]
node_output_compute = [compute_map[r] for r in node.outputs] def perform(self, node, inputs, output_storage):
cA, calpha, cx, cy = inputs
# get vars for containers cZ, = output_storage
cA, calpha, cx, cy = node_input_storage
cZ, = node_output_storage
local_ger = _blas_ger_fns[numpy.dtype(node.inputs[0].type.dtype)]
def rval():
# N.B. some versions of scipy (e.g. mine) don't actually work # N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to. # in-place on a, even when I tell it to.
A = cA[0] A = cA
local_ger = node.tag.local_ger
if A.size == 0: if A.size == 0:
# We don't have to compute anything, A is empty. # We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it # We need this special case because Numpy considers it
...@@ -47,21 +43,13 @@ class ScipyGer(Ger): ...@@ -47,21 +43,13 @@ class ScipyGer(Ger):
# so here to stop DebugMode from complaining. # so here to stop DebugMode from complaining.
A = A.copy() A = A.copy()
elif A.flags['C_CONTIGUOUS']: elif A.flags['C_CONTIGUOUS']:
A = local_ger(calpha[0], cy[0], cx[0], a=A.T, A = local_ger(calpha, cy, cx, a=A.T,
overwrite_a=int(self.destructive)).T overwrite_a=int(self.destructive)).T
else: else:
A = local_ger(calpha[0], cx[0], cy[0], a=A, A = local_ger(calpha, cx, cy, a=A,
overwrite_a=int(self.destructive)) overwrite_a=int(self.destructive))
cZ[0] = A cZ[0] = A
for o in node_output_compute:
o[0] = True
# TODO: If this is currently an unofficial part of the thunk API,
# then maybe it should be documented and made official?
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
scipy_ger_no_inplace = ScipyGer(False) scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True) scipy_ger_inplace = ScipyGer(True)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论