提交 ac8ed0a7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a way to differentiate the old and new gpu backend in ScanInplaceOptimizer.

上级 461f2cdd
......@@ -502,6 +502,7 @@ def gpu_reconstruct_graph(inputs, outputs, tag=None):
@op_lifter([scan_op.Scan])
def local_scan_to_gpua(node):
info = copy.deepcopy(node.op.info)
info['gpu'] = True
info['gpua'] = True
nw_ins = [node.inputs[0]]
e = (1 +
......@@ -528,7 +529,7 @@ def local_scan_to_gpua(node):
tmp_in, tmp_out = gpu_reconstruct_graph(scan_ins, scan_outs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
_cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
# info['gpu_hash'] = hash(_cmodule_key)
info['gpu_hash'] = hash(_cmodule_key)
nw_op = scan_op.Scan(scan_ins, scan_outs, info,
typeConstructor=GpuArrayType).make_node(*nw_ins)
......@@ -536,7 +537,7 @@ def local_scan_to_gpua(node):
optdb.register('gpua_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=GpuArrayType,
gpu_flag=True),
gpua_flag=True),
75,
'gpua',
'fast_run',
......
......@@ -73,6 +73,8 @@ class Scan(PureOp):
will be moved on the GPU if the optimization gets applied (following
Theano's philosophy of moving as much as possible on gpu).
"""
if 'gpua' not in info:
info['gpua'] = False
# adding properties into self
self.inputs = inputs
self.outputs = outputs
......
......@@ -537,10 +537,11 @@ class PushOutSeqScan(gof.Optimizer):
class ScanInplaceOptimizer(Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self, typeConstructor=None, gpu_flag=False):
def __init__(self, typeConstructor=None, gpu_flag=False, gpua_flag=False):
Optimizer.__init__(self)
self.typeConstructor = typeConstructor
self.gpu_flag = gpu_flag
self.gpua_flag = gpua_flag
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
......@@ -551,7 +552,8 @@ class ScanInplaceOptimizer(Optimizer):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu'] == self.gpu_flag)]
x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论