提交 01e852ae authored 作者: Frederic's avatar Frederic

[CRASH] Make sure the var is on the GPU.

上级 3ea6aa18
...@@ -27,6 +27,9 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp): ...@@ -27,6 +27,9 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp):
def make_node(self, x, b, y_idx): def make_node(self, x, b, y_idx):
#N.B. won't work when we don't cast y_idx to float anymore #N.B. won't work when we don't cast y_idx to float anymore
x = as_cuda_ndarray_variable(x)
b = as_cuda_ndarray_variable(b)
y_idx = as_cuda_ndarray_variable(y_idx)
nll = y_idx.type() nll = y_idx.type()
sm = x.type() sm = x.type()
am = y_idx.type() am = y_idx.type()
...@@ -237,6 +240,9 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuOp): ...@@ -237,6 +240,9 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuOp):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, dy, sm, y_idx): def make_node(self, dy, sm, y_idx):
dy = as_cuda_ndarray_variable(dy)
sm = as_cuda_ndarray_variable(sm)
y_idx = as_cuda_ndarray_variable(y_idx)
return Apply(self, [dy, sm, y_idx], [sm.type()]) return Apply(self, [dy, sm, y_idx], [sm.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -379,6 +385,7 @@ class GpuSoftmax(GpuOp): ...@@ -379,6 +385,7 @@ class GpuSoftmax(GpuOp):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
x = as_cuda_ndarray_variable(x)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
...@@ -543,6 +550,7 @@ class GpuSoftmaxWithBias(GpuOp): ...@@ -543,6 +550,7 @@ class GpuSoftmaxWithBias(GpuOp):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x, b): def make_node(self, x, b):
x = as_cuda_ndarray_variable(x)
return Apply(self, [x, b], [x.type()]) return Apply(self, [x, b], [x.type()])
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论