提交 2f986160 作者: Arnaud Bergeron

Add proper context inference for scan.

上级 c032ac66
......@@ -2478,8 +2478,11 @@ def local_gpu_allocempty(node):
return False
def typeInfer(node):
    """Return the type constructor to use when rebuilding a scan op.

    `node` is accepted so the signature matches the `typeInfer` hook
    passed to `ScanInplaceOptimizer`, but it is not inspected here: the
    same `typeConstructor` object is returned for every node.
    """
    return typeConstructor
optdb.register('gpu_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=typeConstructor,
scan_opt.ScanInplaceOptimizer(typeInfer=typeInfer,
gpu_flag=True),
75,
'gpu',
......
......@@ -18,7 +18,7 @@ from theano.tensor.nnet.conv import ConvOp
from theano.tests.breakpoint import PdbBreakpoint
from .type import GpuArrayType, GpuArrayConstant, get_context
from .basic_ops import (as_gpuarray_variable,
from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost,
GpuSplit, GpuContiguous,
......@@ -961,12 +961,23 @@ def local_scan_to_gpua(node, context_name):
_cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
info['gpu_hash'] = hash(_cmodule_key)
def typebuild(dtype, broadcastable, context_name=context_name):
return GpuArrayType(dtype=dtype, broadcastable=broadcastable,
context_name=context_name)
nw_op = scan_op.Scan(scan_ins, scan_outs, info,
typeConstructor=GpuArrayType).make_node(*nw_ins)
typebuild=typebuild).make_node(*nw_ins)
return nw_op.outputs
def _scan_type_infer(node):
    """Build a type constructor bound to the context inferred for `node`.

    The context name is obtained by calling `infer_context_name` on the
    inputs of `node`. The returned callable takes `dtype` and
    `broadcastable` (and optionally an overriding `context_name`) and
    returns the matching `GpuArrayType`.
    """
    inferred_ctx = infer_context_name(*node.inputs)

    # The default argument freezes the inferred context at definition
    # time, so the returned constructor carries it without extra state.
    def typebuild(dtype, broadcastable, context_name=inferred_ctx):
        return GpuArrayType(
            dtype=dtype,
            broadcastable=broadcastable,
            context_name=context_name,
        )

    return typebuild
optdb.register('gpua_scanOp_make_inplace',
scan_opt.ScanInplaceOptimizer(typeConstructor=GpuArrayType,
scan_opt.ScanInplaceOptimizer(typeInfer=_scan_type_infer,
gpua_flag=True),
75,
'gpuarray',
......
......@@ -1014,9 +1014,9 @@ class ScanInplaceOptimizer(Optimizer):
"""
def __init__(self, typeConstructor=None, gpu_flag=False, gpua_flag=False):
def __init__(self, typeInfer=None, gpu_flag=False, gpua_flag=False):
Optimizer.__init__(self)
self.typeConstructor = typeConstructor
self.typeInfer = typeInfer
self.gpu_flag = gpu_flag
self.gpua_flag = gpua_flag
......@@ -1062,10 +1062,15 @@ class ScanInplaceOptimizer(Optimizer):
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
if self.typeInfer is None:
typeConstructor = None
else:
typeConstructor = self.typeInfer(node)
new_op = scan_op.Scan(op.inputs,
op.outputs,
info,
typeConstructor=self.typeConstructor)
typeConstructor=typeConstructor)
# Do not call make_node for test_value
new_outs = new_op(*inputs, **dict(return_list=True))
......@@ -2298,7 +2303,7 @@ scan_eqopt2 = theano.gof.EquilibriumDB()
optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeConstructor=None,
ScanInplaceOptimizer(typeInfer=None,
gpu_flag=False),
75,
'fast_run',
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论