提交 2f986160 作者:Arnaud Bergeron

Add proper context inference for scan.

上级 c032ac66
...@@ -2478,8 +2478,11 @@ def local_gpu_allocempty(node): ...@@ -2478,8 +2478,11 @@ def local_gpu_allocempty(node):
return False return False
def typeInfer(node):
    """Return the type constructor to use when rebuilding a scan op.

    The ``node`` argument is accepted so this function matches the
    ``typeInfer(node)`` callback signature expected by
    ``ScanInplaceOptimizer``, but it is not consulted: the same
    ``typeConstructor`` is returned for every node.
    """
    # NOTE(review): ``typeConstructor`` comes from the enclosing scope and
    # is not visible in this hunk -- confirm where it is defined.
    return typeConstructor
optdb.register('gpu_scanOp_make_inplace', optdb.register('gpu_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=typeConstructor, scan_opt.ScanInplaceOptimizer(typeInfer=typeInfer,
gpu_flag=True), gpu_flag=True),
75, 75,
'gpu', 'gpu',
......
...@@ -18,7 +18,7 @@ from theano.tensor.nnet.conv import ConvOp ...@@ -18,7 +18,7 @@ from theano.tensor.nnet.conv import ConvOp
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from .type import GpuArrayType, GpuArrayConstant, get_context from .type import GpuArrayType, GpuArrayConstant, get_context
from .basic_ops import (as_gpuarray_variable, from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu, host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost, HostFromGpu, GpuFromHost,
GpuSplit, GpuContiguous, GpuSplit, GpuContiguous,
...@@ -961,12 +961,23 @@ def local_scan_to_gpua(node, context_name): ...@@ -961,12 +961,23 @@ def local_scan_to_gpua(node, context_name):
_cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) _cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
info['gpu_hash'] = hash(_cmodule_key) info['gpu_hash'] = hash(_cmodule_key)
def typebuild(dtype, broadcastable, context_name=context_name):
return GpuArrayType(dtype=dtype, broadcastable=broadcastable,
context_name=context_name)
nw_op = scan_op.Scan(scan_ins, scan_outs, info, nw_op = scan_op.Scan(scan_ins, scan_outs, info,
typeConstructor=GpuArrayType).make_node(*nw_ins) typebuild=typebuild).make_node(*nw_ins)
return nw_op.outputs return nw_op.outputs
def _scan_type_infer(node):
    """Build a ``GpuArrayType`` factory bound to the context of ``node``.

    The context name is inferred once from ``node.inputs`` and captured as
    the default value of the returned factory's ``context_name`` argument.

    Parameters
    ----------
    node
        The apply node whose inputs are used to infer the context name.

    Returns
    -------
    callable
        ``typebuild(dtype, broadcastable, context_name=<inferred>)``, which
        returns a ``GpuArrayType`` built from those three values.
    """
    inferred_context = infer_context_name(*node.inputs)

    def typebuild(dtype, broadcastable, context_name=inferred_context):
        # The default is evaluated at definition time, so the factory keeps
        # the context inferred for this particular node.
        return GpuArrayType(
            dtype=dtype,
            broadcastable=broadcastable,
            context_name=context_name,
        )

    return typebuild
optdb.register('gpua_scanOp_make_inplace', optdb.register('gpua_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=GpuArrayType, scan_opt.ScanInplaceOptimizer(typeInfer=_scan_type_infer,
gpua_flag=True), gpua_flag=True),
75, 75,
'gpuarray', 'gpuarray',
......
...@@ -1014,9 +1014,9 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1014,9 +1014,9 @@ class ScanInplaceOptimizer(Optimizer):
""" """
def __init__(self, typeConstructor=None, gpu_flag=False, gpua_flag=False): def __init__(self, typeInfer=None, gpu_flag=False, gpua_flag=False):
Optimizer.__init__(self) Optimizer.__init__(self)
self.typeConstructor = typeConstructor self.typeInfer = typeInfer
self.gpu_flag = gpu_flag self.gpu_flag = gpu_flag
self.gpua_flag = gpua_flag self.gpua_flag = gpua_flag
...@@ -1062,10 +1062,15 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1062,10 +1062,15 @@ class ScanInplaceOptimizer(Optimizer):
ls[idx] = deep_copy_op(ls[idx]) ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end inputs = ls_begin + ls + ls_end
if self.typeInfer is None:
typeConstructor = None
else:
typeConstructor = self.typeInfer(node)
new_op = scan_op.Scan(op.inputs, new_op = scan_op.Scan(op.inputs,
op.outputs, op.outputs,
info, info,
typeConstructor=self.typeConstructor) typeConstructor=typeConstructor)
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True)) new_outs = new_op(*inputs, **dict(return_list=True))
...@@ -2298,7 +2303,7 @@ scan_eqopt2 = theano.gof.EquilibriumDB() ...@@ -2298,7 +2303,7 @@ scan_eqopt2 = theano.gof.EquilibriumDB()
optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan') optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan') optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace', optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeConstructor=None, ScanInplaceOptimizer(typeInfer=None,
gpu_flag=False), gpu_flag=False),
75, 75,
'fast_run', 'fast_run',
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论