提交 f9fbd00f authored 作者: Chinnadhurai Sankar's avatar Chinnadhurai Sankar

fix flake8 errors

上级 98e109fa
...@@ -25,7 +25,7 @@ from . import multinomial ...@@ -25,7 +25,7 @@ from . import multinomial
import theano.sandbox.cuda import theano.sandbox.cuda
from theano.sandbox.cuda import GpuOp from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable
from theano.gpuarray.basic_ops import GpuKernelBase, Kernel, infer_context_name,as_gpuarray_variable from theano.gpuarray.basic_ops import GpuKernelBase, Kernel, infer_context_name, as_gpuarray_variable
from theano.gpuarray.type import GpuArrayType from theano.gpuarray.type import GpuArrayType
from theano.gpuarray.fp16_help import write_w from theano.gpuarray.fp16_help import write_w
from theano.gpuarray.opt import (register_opt as register_gpua, from theano.gpuarray.opt import (register_opt as register_gpua,
...@@ -326,7 +326,6 @@ class mrg_uniform_base(Op): ...@@ -326,7 +326,6 @@ class mrg_uniform_base(Op):
class mrg_uniform(mrg_uniform_base): class mrg_uniform(mrg_uniform_base):
# CPU VERSION # CPU VERSION
def make_node(self, rstate, size): def make_node(self, rstate, size):
# error checking slightly redundant here, since # error checking slightly redundant here, since
# this op should not be called directly. # this op should not be called directly.
...@@ -835,7 +834,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -835,7 +834,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
for i in range(self.output_type.ndim): for i in range(self.output_type.ndim):
broad.append(tensor.extract_constant(size[i]) == 1) broad.append(tensor.extract_constant(size[i]) == 1)
output_type = self.output_type.clone(broadcastable=broad)() output_type = self.output_type.clone(broadcastable=broad)()
rstate = as_gpuarray_variable(rstate,infer_context_name(rstate)) rstate = as_gpuarray_variable(rstate, infer_context_name(rstate))
return Apply(self, return Apply(self,
[rstate, size], [rstate, size],
[rstate.type(), output_type]) [rstate.type(), output_type])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论