提交 a56442e0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Some last fixes for proper context_name passing.

上级 470d02f8
......@@ -15,7 +15,7 @@ from theano.tensor.nnet import SoftmaxGrad
from theano.tensor.signal.downsample import (
DownsampleFactorMax, MaxPoolGrad, AveragePoolGrad)
from . import pygpu, init_dev
from . import pygpu
from .type import get_context, gpu_context_type
from .basic_ops import (as_gpuarray_variable, infer_context_name,
gpu_contiguous, HostFromGpu,
......@@ -29,6 +29,7 @@ from .nnet import GpuSoftmax
from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter
from .opt_util import alpha_merge, output_merge, inplace_allocempty
def _dnn_check_compile():
preambule = """
#include <stdio.h>
......
......@@ -169,8 +169,8 @@ class InputToGpuOptimizer(Optimizer):
isinstance(input.clients[0][0].op, GpuFromHost))):
continue
ctx_name = getattr(input.tag, 'context_name', None)
try:
ctx_name = getattr(input.tag, 'context_name', None)
new_input = host_from_gpu(GpuFromHost(ctx_name)(input))
fgraph.replace_validate(input, new_input,
"InputToGpuOptimizer")
......@@ -180,7 +180,7 @@ class InputToGpuOptimizer(Optimizer):
except ValueError:
# If there is no context tag and no default context
# then it stays on the CPU
if ctx is not None:
if not hasattr(input.tag, 'context_name'):
raise
pass
......@@ -695,7 +695,7 @@ def local_gpua_gemm(node, context_name):
@register_opt('fast_compile')
@op_lifter([tensor.basic.Dot])
def local_gpua_hgemm(node):
def local_gpua_hgemm(node, context_name):
from theano.sandbox.cuda import nvcc_compiler
if nvcc_compiler.nvcc_version < '7.5':
_logger.warning("Not performing dot of float16 on the GPU since "
......@@ -707,8 +707,9 @@ def local_gpua_hgemm(node):
if (A.ndim == 2 and B.ndim == 2 and
A.dtype == 'float16' and B.dtype == 'float16'):
fgraph = node.inputs[0].fgraph
C = GpuAllocEmpty(dtype='float16')(shape_i(A, 0, fgraph),
shape_i(B, 1, fgraph))
C = GpuAllocEmpty(dtype='float16', context_name=context_name)(
shape_i(A, 0, fgraph),
shape_i(B, 1, fgraph))
return gpugemm_no_inplace(C, 1.0, A, B, 0.0)
......@@ -739,7 +740,7 @@ def local_gpua_dot22(node, context_name):
@register_opt('fast_compile')
@op_lifter([tensor.basic.Eye])
def local_gpua_eye(node, context_name):
return GpuEye(dtype=node.op.dtype)
return GpuEye(dtype=node.op.dtype, context_name=context_name)
@register_opt('fast_compile')
......@@ -971,8 +972,10 @@ def local_scan_to_gpua(node, context_name):
typebuild=typebuild).make_node(*nw_ins)
return nw_op.outputs
def _scan_type_infer(node):
context_name = infer_context_name(*node.inputs)
def typebuild(dtype, broadcastable, context_name=context_name):
return GpuArrayType(dtype=dtype, broadcastable=broadcastable,
context_name=context_name)
......
from __future__ import print_function
import os
import copy
import numpy
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论