提交 b152e5b9 authored 作者: Frederic's avatar Frederic

[TESTS] this fix a test while having concurrent PR that force pep8.

上级 c89e1bc2
...@@ -18,28 +18,29 @@ class CpuContiguous(theano.Op): ...@@ -18,28 +18,29 @@ class CpuContiguous(theano.Op):
""" """
__props__ = () __props__ = ()
view_map = {0: [0]} view_map = {0: [0]}
def make_node(self, x): def make_node(self, x):
x_ = theano.tensor.as_tensor_variable(x) x_ = theano.tensor.as_tensor_variable(x)
return theano.Apply(self, [x_], [x_.type()]) return theano.Apply(self, [x_], [x_.type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x, = inputs x, = inputs
y = output_storage[0] y = output_storage[0]
# if the ouput is contiguous do nothing, else copy # if the ouput is contiguous do nothing, else copy
# the input # the input
if not x.flags['C_CONTIGUOUS']: if not x.flags['C_CONTIGUOUS']:
x = x.copy() x = x.copy()
assert x.flags['C_CONTIGUOUS'] assert x.flags['C_CONTIGUOUS']
y[0] = x y[0] = x
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
x, = inames x, = inames
y, = onames y, = onames
code = """ code = """
if (!PyArray_CHKFLAGS(%(x)s, NPY_ARRAY_C_CONTIGUOUS)){ if (!PyArray_CHKFLAGS(%(x)s, NPY_ARRAY_C_CONTIGUOUS)){
// check to see if output is contiguous first // check to see if output is contiguous first
if (%(y)s != NULL && PyArray_CHKFLAGS(%(y)s, NPY_ARRAY_C_CONTIGUOUS)){ if (%(y)s != NULL &&
PyArray_CHKFLAGS(%(y)s, NPY_ARRAY_C_CONTIGUOUS)){
PyArray_CopyInto(%(y)s, %(x)s); PyArray_CopyInto(%(y)s, %(x)s);
} }
else{ else{
...@@ -54,12 +55,13 @@ class CpuContiguous(theano.Op): ...@@ -54,12 +55,13 @@ class CpuContiguous(theano.Op):
} }
""" % locals() """ % locals()
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (0,)
cpu_contiguous = CpuContiguous() cpu_contiguous = CpuContiguous()
class CumsumOp(theano.Op): class CumsumOp(theano.Op):
# See function cumsum for docstring # See function cumsum for docstring
def __init__(self, axis=None): def __init__(self, axis=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论