Commit 98e109fa authored by Chinnadhurai Sankar

fix crash in MRG_RandomStreams with the new backend

Parent a7900bd8
...@@ -24,7 +24,8 @@ from . import multinomial ...@@ -24,7 +24,8 @@ from . import multinomial
import theano.sandbox.cuda import theano.sandbox.cuda
from theano.sandbox.cuda import GpuOp from theano.sandbox.cuda import GpuOp
from theano.gpuarray.basic_ops import GpuKernelBase, Kernel, infer_context_name from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable
from theano.gpuarray.basic_ops import GpuKernelBase, Kernel, infer_context_name,as_gpuarray_variable
from theano.gpuarray.type import GpuArrayType from theano.gpuarray.type import GpuArrayType
from theano.gpuarray.fp16_help import write_w from theano.gpuarray.fp16_help import write_w
from theano.gpuarray.opt import (register_opt as register_gpua, from theano.gpuarray.opt import (register_opt as register_gpua,
...@@ -312,19 +313,6 @@ class mrg_uniform_base(Op): ...@@ -312,19 +313,6 @@ class mrg_uniform_base(Op):
s = "no_inplace" s = "no_inplace"
return self.__class__.__name__ + "{%s,%s}" % (self.output_type, s) return self.__class__.__name__ + "{%s,%s}" % (self.output_type, s)
def make_node(self, rstate, size):
    """Build the Apply node for this random op.

    The error checking here is slightly redundant: this op is not meant
    to be called directly — go through ``MRG_RandomStreams`` instead.

    Parameters
    ----------
    rstate : Variable
        The MRG random-state input (passed through unchanged).
    size : Variable
        Symbolic shape of the requested sample.

    Returns
    -------
    Apply
        Node with inputs ``[rstate, size]`` and outputs
        ``[updated rstate, sample]``; the sample's broadcastable pattern
        is derived from any constant-1 entries of ``size``.
    """
    # A dimension is broadcastable exactly when its size is the constant 1.
    broad = [tensor.extract_constant(size[d]) == 1
             for d in range(self.output_type.ndim)]
    sample_var = self.output_type.clone(broadcastable=broad)()
    return Apply(self, [rstate, size], [rstate.type(), sample_var])
def grad(self, inputs, ograd): def grad(self, inputs, ograd):
return [gradient.grad_undefined(self, k, inp, return [gradient.grad_undefined(self, k, inp,
'No gradient defined through ' 'No gradient defined through '
...@@ -338,6 +326,21 @@ class mrg_uniform_base(Op): ...@@ -338,6 +326,21 @@ class mrg_uniform_base(Op):
class mrg_uniform(mrg_uniform_base): class mrg_uniform(mrg_uniform_base):
# CPU VERSION # CPU VERSION
def make_node(self, rstate, size):
    """Build the Apply node for the CPU sampler.

    Error checking is slightly redundant here: this op should not be
    called directly — use ``MRG_RandomStreams`` instead.

    Parameters
    ----------
    rstate : array-like or Variable
        MRG random state; coerced to a ``TensorVariable``.
    size : Variable
        Symbolic shape of the requested sample.

    Returns
    -------
    Apply
        Node producing ``[updated rstate, sample]``; the sample is
        broadcastable along any axis whose size is the constant 1.
    """
    broad = [tensor.extract_constant(size[d]) == 1
             for d in range(self.output_type.ndim)]
    sample_var = self.output_type.clone(broadcastable=broad)()
    # CPU version: make sure the state is a plain tensor variable.
    rstate = as_tensor_variable(rstate)
    return Apply(self, [rstate, size], [rstate.type(), sample_var])
@classmethod @classmethod
def new(cls, rstate, ndim, dtype, size): def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size) v_size = as_tensor_variable(size)
...@@ -564,6 +567,20 @@ class mrg_uniform(mrg_uniform_base): ...@@ -564,6 +567,20 @@ class mrg_uniform(mrg_uniform_base):
class GPU_mrg_uniform(mrg_uniform_base, GpuOp): class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
# GPU VERSION # GPU VERSION
def make_node(self, rstate, size):
    """Build the Apply node for the old CUDA-backend sampler.

    Error checking is slightly redundant here: this op should not be
    called directly — use ``MRG_RandomStreams`` instead.

    Parameters
    ----------
    rstate : array-like or Variable
        MRG random state; coerced to a CUDA ndarray variable.
    size : Variable
        Symbolic shape of the requested sample.

    Returns
    -------
    Apply
        Node producing ``[updated rstate, sample]``; the sample is
        broadcastable along any axis whose size is the constant 1.
    """
    broad = [tensor.extract_constant(size[d]) == 1
             for d in range(self.output_type.ndim)]
    sample_var = self.output_type.clone(broadcastable=broad)()
    # Old GPU backend: the state must live on the device.
    rstate = as_cuda_ndarray_variable(rstate)
    return Apply(self, [rstate, size], [rstate.type(), sample_var])
@classmethod @classmethod
def new(cls, rstate, ndim, dtype, size): def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size) v_size = as_tensor_variable(size)
...@@ -809,6 +826,20 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -809,6 +826,20 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
# GpuArray version # GpuArray version
_f16_ok = True _f16_ok = True
def make_node(self, rstate, size):
    """Build the Apply node for the gpuarray-backend sampler.

    Error checking is slightly redundant here: this op should not be
    called directly — use ``MRG_RandomStreams`` instead.

    Parameters
    ----------
    rstate : array-like or Variable
        MRG random state; coerced to a ``GpuArray`` variable on the
        context inferred from ``rstate`` itself.
    size : Variable
        Symbolic shape of the requested sample.

    Returns
    -------
    Apply
        Node producing ``[updated rstate, sample]``; the sample is
        broadcastable along any axis whose size is the constant 1.
    """
    broad = [tensor.extract_constant(size[d]) == 1
             for d in range(self.output_type.ndim)]
    sample_var = self.output_type.clone(broadcastable=broad)()
    # New gpuarray backend: move the state onto its own GPU context.
    rstate = as_gpuarray_variable(rstate, infer_context_name(rstate))
    return Apply(self, [rstate, size], [rstate.type(), sample_var])
def get_params(self, node): def get_params(self, node):
return node.inputs[0].type.context return node.inputs[0].type.context
......
Markdown formatting
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment