提交 b90976ee authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix flake8 format of gpuarray/subtensor.py

上级 c0342e58
......@@ -4,7 +4,7 @@ import numpy
import os
import theano
from theano import tensor, gof, Op, config
from theano import tensor, gof, config
from six.moves import StringIO
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
import theano.tensor.inplace
......@@ -243,16 +243,16 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
if sub_x.shape:
# we've sliced out an N-D tensor with N > 0
if not self.set_instead_of_inc:
#sub_x += y
pygpu.elemwise.ielemwise2(sub_x, '+', y, broadcast=False)
# sub_x += y
pygpu.elemwise.ielemwise2(sub_x, '+', y, broadcast=False)
else:
#sub_x += -sub_x + y
# sub_x += -sub_x + y
x.__setitem__(cdata, y)
else:
# scalar case
if not self.set_instead_of_inc:
#x.__setitem__(cdata, sub_x + y)
tmp = pygpu.elemwise.elemwise2(sub_x, '+', y, sub_x,
# x.__setitem__(cdata, sub_x + y)
tmp = pygpu.elemwise.elemwise2(sub_x, '+', y, sub_x,
broadcast=False)
x.__setitem__(cdata, tmp)
else:
......@@ -261,9 +261,6 @@ class GpuIncSubtensor(GpuKernelBase, IncSubtensor):
def __setstate__(self, d):
self.__dict__.update(d)
owner = getattr(self.__dict__, "owner", None)
if owner:
op.create_iadd_node(owner)
def __getstate__(self):
d = copy.copy(self.__dict__)
......@@ -558,7 +555,7 @@ class GpuAdvancedIncSubtensor1(HideC, tensor.AdvancedIncSubtensor1):
reshaped_y = y.reshape(y.shape[1:])
else:
nb_dims_to_add = (x.ndim - 1) - y.ndim
reshaped_y = y.reshape((1,)*nb_dims_to_add + y.shape)
reshaped_y = y.reshape((1,) * nb_dims_to_add + y.shape)
if self.set_instead_of_inc:
for i in idx:
......@@ -633,9 +630,9 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, GpuAdvancedIncSubtensor1):
device_properties = theano.sandbox.cuda.device_properties
compute_capability = device_properties(active_device_no)['major']
if ((self.set_instead_of_inc) or
(node.inputs[0].ndim != node.inputs[1].ndim) or
(node.inputs[0].ndim != 2) or
(compute_capability < 2)):
(node.inputs[0].ndim != node.inputs[1].ndim) or
(node.inputs[0].ndim != 2) or
(compute_capability < 2)):
raise NotImplementedError("This case does not have C code yet.")
x = inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论