提交 913f11c5 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make a generic function for the new backend and make sure to compile on GPU.

上级 13b48305
......@@ -10,6 +10,7 @@ from theano.configparser import config
import theano.tensor as T
import theano.sandbox.cuda as cuda
from theano.compile import Mode
from .mode import get_mode
try:
from theano.gpuarray.type import GpuArrayType, _name_for_ctx
......@@ -200,45 +201,25 @@ def compile_gpu_func(nan_is_error, inf_is_error, big_is_error):
cuda_compile_failed = True
def f_gpua_min(inp):
dt = inp.dtype
ctx_name = _name_for_ctx(inp.context)
key = (dt, ctx_name)
f = f_gpua_min.cache.get(key, None)
if f is None:
guard_in = GpuArrayType(str(dt), (False,), context_name=ctx_name)()
f = theano.function([guard_in], T.min(guard_in),
mode='FAST_RUN', profile=False)
f_gpua_min.cache[key] = f
return f(inp)
f_gpua_min.cache = dict()
def f_gpua_max(inp):
dt = inp.dtype
ctx_name = _name_for_ctx(inp.context)
key = (dt, ctx_name)
f = f_gpua_min.cache.get(key, None)
if f is None:
guard_in = GpuArrayType(str(dt), (False,), context_name=ctx_name)()
f = theano.function([guard_in], T.max(guard_in),
mode='FAST_RUN', profile=False)
f_gpua_min.cache[key] = f
return f(inp)
f_gpua_max.cache = dict()
def f_gpua_absmax(inp):
dt = inp.dtype
ctx_name = _name_for_ctx(inp.context)
key = (dt, ctx_name)
f = f_gpua_min.cache.get(key, None)
if f is None:
guard_in = GpuArrayType(str(dt), (False,), context_name=ctx_name)()
f = theano.function([guard_in], T.max(T.abs_(guard_in)),
mode='FAST_RUN', profile=False)
f_gpua_min.cache[key] = f
return f(inp)
f_gpua_absmax.cache = dict()
def f_compute(op):
def result(inp):
dtype = inp.dtype
ctx_name = _name_for_ctx(inp.context)
key = (dtype, ctx_name)
f = result.cache.get(key, None)
if f is None:
guard_in = GpuArrayType(str(dtype), (False,), context_name=ctx_name)()
mode = get_mode('FAST_RUN').including('gpuarray')
f = theano.function([guard_in], op(guard_in),
mode=mode, profile=False)
result.cache[key] = f
return f(inp)
result.cache = dict()
return result
f_gpua_min = f_compute(T.min)
f_gpua_max = f_compute(T.max)
f_gpua_absmax = f_compute(lambda x: T.max(T.abs_(x)))
class NanGuardMode(Mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论