提交 cd037afb authored 作者: Frederic's avatar Frederic

Remove import not used and fix output dtype computation.

上级 024347f9
...@@ -3,9 +3,7 @@ from theano.compat.six import StringIO ...@@ -3,9 +3,7 @@ from theano.compat.six import StringIO
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
from theano.sandbox.cuda.kernel_codegen import (nvcc_kernel,
inline_softmax,
inline_softmax_fixed_shared)
try: try:
import pygpu import pygpu
from pygpu import gpuarray, elemwise from pygpu import gpuarray, elemwise
...@@ -13,6 +11,7 @@ except ImportError: ...@@ -13,6 +11,7 @@ except ImportError:
pass pass
from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable
from theano.sandbox.gpuarray.type import GpuArrayType
class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...@@ -36,7 +35,8 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): ...@@ -36,7 +35,8 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
x = as_gpuarray_variable(x) x = as_gpuarray_variable(x)
b = as_gpuarray_variable(b) b = as_gpuarray_variable(b)
y_idx = as_gpuarray_variable(y_idx) y_idx = as_gpuarray_variable(y_idx)
nll = y_idx.type() nll = GpuArrayType(x.type.dtype,
y_idx.type.broadcastable)()
sm = x.type() sm = x.type()
am = y_idx.type() am = y_idx.type()
return Apply(self, [x, b, y_idx], [nll, sm, am]) return Apply(self, [x, b, y_idx], [nll, sm, am])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论