提交 ac850ee6 authored 作者: Frederic's avatar Frederic

some pep8

上级 689e5680
...@@ -594,8 +594,7 @@ class Subtensor(Op): ...@@ -594,8 +594,7 @@ class Subtensor(Op):
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim, def helper_c_code(node, name, inputs, outputs, sub, idx_list, view_ndim,
c_prefix=None, c_prefix=None,
strides_mul=None, strides_mul=None):
):
""" """
The parameters c_prefix are there to allow reusing this The parameters c_prefix are there to allow reusing this
function on PyArray and CudaNdarray object. function on PyArray and CudaNdarray object.
...@@ -682,7 +681,8 @@ class Subtensor(Op): ...@@ -682,7 +681,8 @@ class Subtensor(Op):
subensor_spec = "npy_intp * subtensor_spec = NULL;" subensor_spec = "npy_intp * subtensor_spec = NULL;"
if is_slice: if is_slice:
is_slice_init = "int is_slice[] = {" + ",".join([str(s) for s in is_slice]) + "};" is_slice_init = "int is_slice[] = {" + ",".join([str(s) for s in
is_slice]) + "};"
else: else:
is_slice_init = "int* is_slice = NULL;" is_slice_init = "int* is_slice = NULL;"
subtensor_init = "\n".join(init_cmds) subtensor_init = "\n".join(init_cmds)
...@@ -897,7 +897,8 @@ class Subtensor(Op): ...@@ -897,7 +897,8 @@ class Subtensor(Op):
%(z)s = xview; %(z)s = xview;
""" % locals() """ % locals()
return decl + checkNDim + "{" + get_xview + build_view + finish_view + "}" return (decl + checkNDim +
"{" + get_xview + build_view + finish_view + "}")
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
...@@ -1022,9 +1023,9 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -1022,9 +1023,9 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
destroyhandler_tolerate_aliased = [[0, 1]] destroyhandler_tolerate_aliased = [[0, 1]]
else: else:
destroyhandler_tolerate_aliased = [] destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc, the_op = IncSubtensor(
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased x.owner.op.idx_list, inplace, set_instead_of_inc,
) destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased)
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:] real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs) return the_op(real_x, y, *real_idxargs)
...@@ -1105,10 +1106,10 @@ class IncSubtensor(Op): ...@@ -1105,10 +1106,10 @@ class IncSubtensor(Op):
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) \ return (type(self) == type(other) and
and self.idx_list == other.idx_list \ self.idx_list == other.idx_list and
and self.inplace == other.inplace \ self.inplace == other.inplace and
and self.set_instead_of_inc == other.set_instead_of_inc self.set_instead_of_inc == other.set_instead_of_inc)
def __hash__(self): def __hash__(self):
msg = [] msg = []
...@@ -1120,12 +1121,12 @@ class IncSubtensor(Op): ...@@ -1120,12 +1121,12 @@ class IncSubtensor(Op):
idx_list = tuple(msg) idx_list = tuple(msg)
# backport # backport
#idx_list = tuple((entry.start, entry.stop, entry.step) # idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice) # if isinstance(entry, slice)
# else entry # else entry
# for entry in self.idx_list) # for entry in self.idx_list)
return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) \ return (hashtype(self) ^ hash(idx_list) ^ hash(self.inplace) ^
^ hash(self.set_instead_of_inc) hash(self.set_instead_of_inc))
def __str__(self): def __str__(self):
indices = [] indices = []
...@@ -1225,7 +1226,7 @@ class IncSubtensor(Op): ...@@ -1225,7 +1226,7 @@ class IncSubtensor(Op):
if not self.set_instead_of_inc: if not self.set_instead_of_inc:
sub_x += y sub_x += y
else: else:
#sub_x += -sub_x + y # sub_x += -sub_x + y
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
else: else:
# scalar case # scalar case
...@@ -1469,7 +1470,7 @@ class IncSubtensor(Op): ...@@ -1469,7 +1470,7 @@ class IncSubtensor(Op):
axis_to_sum.append(i) axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and elif (gy.broadcastable[i] is True and
y_broad[i] is False): y_broad[i] is False):
# This mean that THeano where able to infer that # This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we # gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine. # didn't know it. It is fine.
pass pass
......
...@@ -160,7 +160,7 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -160,7 +160,7 @@ class Test_inc_subtensor(unittest.TestCase):
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray([[9, 9.]]), )) numpy.asarray([[9, 9.]]), ))
#single element # single element
utt.verify_grad( utt.verify_grad(
f_slice(2, 1), f_slice(2, 1),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论