提交 6349790c authored 作者: Frederic Bastien's avatar Frederic Bastien

Make ScipyGer use prepare_node

上级 ca40ef22
......@@ -2172,8 +2172,8 @@ class GpuConv(GpuOp):
bmode = 0
if max_threads_dim0 is None:
raise NotImplementedError("GpuConv.c_code should not be called "
"directly. It should be called by "
"make_thunk() that add some information "
"directly. It should be called after "
"prepare_node() that add some information "
"related to the selected GPU.")
sub.update(locals())
return """
......
......@@ -22,46 +22,34 @@ if have_fblas:
class ScipyGer(Ger):
# keep everything else, but override the make_thunk
def make_thunk(self, node, storage_map, compute_map, no_recycling):
    """Build a zero-argument thunk that evaluates ``node`` with a ger routine.

    The routine is looked up in ``_blas_ger_fns`` using the dtype of the
    node's first input.  The returned callable reads its operands from the
    storage cells of ``node.inputs``, writes the result into the storage
    cell of the single output and marks every output as computed.

    ``no_recycling`` is accepted but not used by this implementation.
    """
    input_cells = [storage_map[var] for var in node.inputs]
    output_cells = [storage_map[var] for var in node.outputs]
    computed_flags = [compute_map[var] for var in node.outputs]

    # Each cell is a one-element container; index 0 holds the value.
    a_cell, alpha_cell, x_cell, y_cell = input_cells
    (z_cell,) = output_cells

    ger = _blas_ger_fns[numpy.dtype(node.inputs[0].type.dtype)]

    def rval():
        # N.B. some versions of scipy do not actually work in-place on
        # ``a``, even when asked to, so the returned array is always the
        # one that gets stored.
        result = a_cell[0]
        if result.size != 0:
            if result.flags['C_CONTIGUOUS']:
                # Hand the routine the transposed view with x and y
                # swapped, then transpose the answer back.
                result = ger(alpha_cell[0], y_cell[0], x_cell[0],
                             a=result.T,
                             overwrite_a=int(self.destructive)).T
            else:
                result = ger(alpha_cell[0], x_cell[0], y_cell[0],
                             a=result,
                             overwrite_a=int(self.destructive))
        elif not self.destructive:
            # Nothing to compute: A is empty.  The empty case is handled
            # separately because NumPy reports an empty array as
            # C-contiguous, which is confusing.  NumPy sometimes thinks
            # empty matrices can share memory, so copy to stop DebugMode
            # from complaining.
            result = result.copy()
        z_cell[0] = result
        for flag in computed_flags:
            flag[0] = True

    # TODO: If this is currently an unofficial part of the thunk API,
    # then maybe it should be documented and made official?
    rval.lazy = False
    rval.inputs = input_cells
    rval.outputs = output_cells
    return rval
def prepare_node(self, node, storage_map, compute_map, impl='py'):
    """Cache the dtype-specific ger routine on ``node.tag`` for ``perform``.

    Parameters
    ----------
    node
        The apply node being prepared; the dtype of its first input
        selects the routine from ``_blas_ger_fns``.
    storage_map, compute_map
        Accepted for interface compatibility; not used here.
    impl
        Name of the implementation being prepared.  The routine is only
        cached when this is ``'py'``, i.e. when ``perform`` will run.

    Notes
    -----
    The body used to test ``impl`` without it being a parameter or a local
    name, which raises ``NameError`` as soon as the method is called.  It
    is now a trailing parameter defaulting to ``'py'`` so that callers
    passing only three arguments still get ``node.tag.local_ger`` set,
    which ``perform`` requires.
    """
    if impl == 'py':
        node.tag.local_ger = _blas_ger_fns[numpy.dtype(
            node.inputs[0].type.dtype)]
def perform(self, node, inputs, output_storage):
    """Evaluate the node in Python with the ger routine cached on the node.

    ``inputs`` is unpacked as ``(A, alpha, x, y)`` and the result is stored
    in the first slot of the single entry of ``output_storage``.  The
    routine is read from ``node.tag.local_ger``, which ``prepare_node``
    sets.
    """
    matrix, alpha, x, y = inputs
    (out_cell,) = output_storage
    ger = node.tag.local_ger

    # N.B. some versions of scipy do not actually work in-place on ``a``,
    # even when asked to, so the returned array is always the one stored.
    if matrix.size != 0:
        if matrix.flags['C_CONTIGUOUS']:
            # Hand the routine the transposed view with x and y swapped,
            # then transpose the answer back.
            matrix = ger(alpha, y, x, a=matrix.T,
                         overwrite_a=int(self.destructive)).T
        else:
            matrix = ger(alpha, x, y, a=matrix,
                         overwrite_a=int(self.destructive))
    elif not self.destructive:
        # Nothing to compute: A is empty.  The empty case is handled
        # separately because NumPy reports an empty array as C-contiguous,
        # which is confusing.  NumPy sometimes thinks empty matrices can
        # share memory, so copy to stop DebugMode from complaining.
        matrix = matrix.copy()
    out_cell[0] = matrix
# Module-level ScipyGer instances, one per mode.
# NOTE(review): the constructor argument is presumably the ``destructive``
# flag that perform() reads as ``self.destructive`` (False: leave A alone,
# True: allow overwriting A) -- confirm against Ger.__init__.
scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论