提交 264ea310 authored 作者: Caglar's avatar Caglar

fixed the flake8 errors.

上级 73c28a05
...@@ -1220,7 +1220,7 @@ class ShapeFeature(object): ...@@ -1220,7 +1220,7 @@ class ShapeFeature(object):
# But we never timed this speed optimization! # But we never timed this speed optimization!
self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(new_shape[idx]) or
self.lscalar_one.equals(T.extract_constant(new_shape[idx], self.lscalar_one.equals(T.extract_constant(new_shape[idx],
only_process_constants=True)) only_process_constants=True))
for idx in xrange(r.ndim)]) for idx in xrange(r.ndim)])
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
...@@ -2393,7 +2393,7 @@ def local_useless_inc_subtensor(node): ...@@ -2393,7 +2393,7 @@ def local_useless_inc_subtensor(node):
idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list) idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
if all(isinstance(e, slice) and e.start is None and if all(isinstance(e, slice) and e.start is None and
e.stop is None and (e.step is None or T.extract_constant(e.step, e.stop is None and (e.step is None or T.extract_constant(e.step,
only_process_constants=True) == -1) only_process_constants=True) == -1)
for e in idx_cst): for e in idx_cst):
# IncSubtensor broadcast node.inputs[1] on node.inputs[0] # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
# based on run time shapes, so we must check they are the same. # based on run time shapes, so we must check they are the same.
...@@ -2464,7 +2464,7 @@ def local_useless_slice(node): ...@@ -2464,7 +2464,7 @@ def local_useless_slice(node):
# check if slice and then check slice indices # check if slice and then check slice indices
if (isinstance(s, slice) and s.start is None and s.stop is None and if (isinstance(s, slice) and s.start is None and s.stop is None and
(s.step is None or T.extract_constant(s.step, (s.step is None or T.extract_constant(s.step,
only_process_constants=True) == 1)): only_process_constants=True) == 1)):
last_slice -= 1 last_slice -= 1
else: else:
break break
...@@ -2580,7 +2580,7 @@ def local_useless_subtensor(node): ...@@ -2580,7 +2580,7 @@ def local_useless_subtensor(node):
elif idx.owner is not None and isinstance(idx.owner.op, T.ARange): elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
try: try:
start, stop, step = map(lambda x: get_scalar_constant_value(x, start, stop, step = map(lambda x: get_scalar_constant_value(x,
only_process_constants=True), only_process_constants=True),
idx.owner.inputs) idx.owner.inputs)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -3577,7 +3577,7 @@ def local_join_empty(node): ...@@ -3577,7 +3577,7 @@ def local_join_empty(node):
new_inputs = [] new_inputs = []
try: try:
join_idx = get_scalar_constant_value(node.inputs[0], join_idx = get_scalar_constant_value(node.inputs[0],
only_process_constants=True) only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return return
for idx in xrange(1, len(node.inputs)): for idx in xrange(1, len(node.inputs)):
...@@ -3737,7 +3737,7 @@ def local_useless_switch(node): ...@@ -3737,7 +3737,7 @@ def local_useless_switch(node):
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False, cond = T.extract_constant(node.inputs[0], elemwise=False,
only_process_constants=True) only_process_constants=True)
if type(cond) is numpy.ndarray and cond.ndim == 0: if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0: if cond == 0:
correct_out = node.inputs[2] correct_out = node.inputs[2]
...@@ -3900,7 +3900,7 @@ def local_div_switch_sink(node): ...@@ -3900,7 +3900,7 @@ def local_div_switch_sink(node):
switch = node.inputs[0].owner switch = node.inputs[0].owner
try: try:
if get_scalar_constant_value(switch.inputs[1], if get_scalar_constant_value(switch.inputs[1],
only_process_constants=True) == 0.: only_process_constants=True) == 0.:
fdiv = op(switch.inputs[2], node.inputs[1]) fdiv = op(switch.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op # Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op. # from previous elementwise multiplication op.
...@@ -3923,7 +3923,7 @@ def local_div_switch_sink(node): ...@@ -3923,7 +3923,7 @@ def local_div_switch_sink(node):
pass pass
try: try:
if get_scalar_constant_value(switch.inputs[2], if get_scalar_constant_value(switch.inputs[2],
only_process_constants=True) == 0.: only_process_constants=True) == 0.:
fdiv = op(switch.inputs[1], node.inputs[1]) fdiv = op(switch.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op # Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op. # from previous elementwise multiplication op.
...@@ -3989,7 +3989,7 @@ def local_useless_tile(node): ...@@ -3989,7 +3989,7 @@ def local_useless_tile(node):
if isinstance(node.op, T.Tile): if isinstance(node.op, T.Tile):
try: try:
a = T.get_scalar_constant_value(node.inputs[1], a = T.get_scalar_constant_value(node.inputs[1],
only_process_constants=True) only_process_constants=True)
if a == 1: if a == 1:
try: try:
l = T.get_vector_length(node.inputs[1]) l = T.get_vector_length(node.inputs[1])
...@@ -4173,7 +4173,7 @@ if 0: ...@@ -4173,7 +4173,7 @@ if 0:
def tmp(thing): def tmp(thing):
try: try:
return T.get_scalar_constant_value(thing, return T.get_scalar_constant_value(thing,
only_process_constants=True) only_process_constants=True)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
print(e, thing.owner.inputs[0]) print(e, thing.owner.inputs[0])
return None return None
...@@ -5221,7 +5221,7 @@ def local_reduce_join(node): ...@@ -5221,7 +5221,7 @@ def local_reduce_join(node):
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
join_axis = get_scalar_constant_value(join.inputs[0], join_axis = get_scalar_constant_value(join.inputs[0],
only_process_constants=True) only_process_constants=True)
if join_axis != reduce_axis[0]: if join_axis != reduce_axis[0]:
return return
...@@ -5305,7 +5305,7 @@ def local_opt_alloc(node): ...@@ -5305,7 +5305,7 @@ def local_opt_alloc(node):
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input, val = get_scalar_constant_value(input,
only_process_constants=True) only_process_constants=True)
assert val.size == 1 assert val.size == 1
# check which type of op # check which type of op
casted = T.mul(*shapes).astype(str(input.dtype)) casted = T.mul(*shapes).astype(str(input.dtype))
...@@ -5320,7 +5320,7 @@ def local_opt_alloc(node): ...@@ -5320,7 +5320,7 @@ def local_opt_alloc(node):
else: else:
try: try:
val = get_scalar_constant_value(input, val = get_scalar_constant_value(input,
only_process_constants=True) only_process_constants=True)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
...@@ -5765,7 +5765,7 @@ def local_abs_merge(node): ...@@ -5765,7 +5765,7 @@ def local_abs_merge(node):
elif isinstance(i, Constant): elif isinstance(i, Constant):
try: try:
const = get_scalar_constant_value(i, const = get_scalar_constant_value(i,
only_process_constants=True) only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if not (const >= 0).all(): if not (const >= 0).all():
...@@ -6348,7 +6348,7 @@ def local_grad_log_erfc_neg(node): ...@@ -6348,7 +6348,7 @@ def local_grad_log_erfc_neg(node):
try: try:
cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0], cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0],
only_process_constants=True) only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -6376,7 +6376,7 @@ def local_grad_log_erfc_neg(node): ...@@ -6376,7 +6376,7 @@ def local_grad_log_erfc_neg(node):
x = erfc_x x = erfc_x
try: try:
cst = get_scalar_constant_value(erfc_x.owner.inputs[0], cst = get_scalar_constant_value(erfc_x.owner.inputs[0],
only_process_constants=True) only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if cst2 != -cst * 2: if cst2 != -cst * 2:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论