提交 b25fb1be authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent unnecessary shadowing of builtin input

上级 9078756f
...@@ -644,10 +644,10 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -644,10 +644,10 @@ def local_dimshuffle_lift(fgraph, node):
if not isinstance(op, DimShuffle): if not isinstance(op, DimShuffle):
return False return False
input = node.inputs[0] inp = node.inputs[0]
inode = input.owner inode = inp.owner
new_order = op.new_order new_order = op.new_order
if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[input]) == 1): if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
new_inputs = [] new_inputs = []
for inp in inode.inputs: for inp in inode.inputs:
...@@ -658,12 +658,12 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -658,12 +658,12 @@ def local_dimshuffle_lift(fgraph, node):
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order] new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
input = inode.inputs[0] inp = inode.inputs[0]
if is_dimshuffle_useless(new_order, input): if is_dimshuffle_useless(new_order, inp):
return [input] return [inp]
elif inode and isinstance(inode.op, DimShuffle): elif inode and isinstance(inode.op, DimShuffle):
ret = op.__class__(input.type.broadcastable, new_order)(input) ret = op.__class__(inp.type.broadcastable, new_order)(inp)
ret = apply_local_dimshuffle_lift(fgraph, ret) ret = apply_local_dimshuffle_lift(fgraph, ret)
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
...@@ -691,7 +691,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -691,7 +691,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
return False return False
new_order = node.inputs[0].owner.op.new_order new_order = node.inputs[0].owner.op.new_order
input = node.inputs[0].owner.inputs[0] inp = node.inputs[0].owner.inputs[0]
broadcastables = node.inputs[0].broadcastable broadcastables = node.inputs[0].broadcastable
new_order_of_nonbroadcast = [] new_order_of_nonbroadcast = []
for i, bd in zip(new_order, broadcastables): for i, bd in zip(new_order, broadcastables):
...@@ -703,7 +703,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node): ...@@ -703,7 +703,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
) )
if no_change_in_order: if no_change_in_order:
shape = node.inputs[1] shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape) ret = op.__class__(node.outputs[0].ndim)(inp, shape)
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
...@@ -744,7 +744,7 @@ class MakeVectorPrinter: ...@@ -744,7 +744,7 @@ class MakeVectorPrinter:
old_precedence = getattr(pstate, "precedence", None) old_precedence = getattr(pstate, "precedence", None)
try: try:
pstate.precedence = 1000 pstate.precedence = 1000
s = [pstate.pprinter.process(input) for input in r.owner.inputs] s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
finally: finally:
pstate.precedence = old_precedence pstate.precedence = old_precedence
return f"[{', '.join(s)}]" return f"[{', '.join(s)}]"
...@@ -1636,12 +1636,12 @@ def local_fill_sink(fgraph, node): ...@@ -1636,12 +1636,12 @@ def local_fill_sink(fgraph, node):
return False return False
models = [] models = []
inputs = [] inputs = []
for input in node.inputs: for inp in node.inputs:
if input.owner and input.owner.op == fill: if inp.owner and inp.owner.op == fill:
models.append(input.owner.inputs[0]) models.append(inp.owner.inputs[0])
inputs.append(input.owner.inputs[1]) inputs.append(inp.owner.inputs[1])
else: else:
inputs.append(input) inputs.append(inp)
if not models: if not models:
return False return False
c = node.op(*inputs) c = node.op(*inputs)
...@@ -1765,16 +1765,16 @@ def local_useless_alloc(fgraph, node): ...@@ -1765,16 +1765,16 @@ def local_useless_alloc(fgraph, node):
if not isinstance(node.op, Alloc): if not isinstance(node.op, Alloc):
return False return False
input = node.inputs[0] inp = node.inputs[0]
output = node.outputs[0] output = node.outputs[0]
if input.type == output.type: if inp.type == output.type:
if input.ndim == 0: if inp.ndim == 0:
return [input] return [inp]
else: else:
return [ return [
Assert("Shapes must be equal")( Assert("Shapes must be equal")(
input, at_all(eq(input.shape, node.inputs[1:])) inp, at_all(eq(inp.shape, node.inputs[1:]))
) )
] ]
...@@ -1799,13 +1799,13 @@ def local_canonicalize_alloc(fgraph, node): ...@@ -1799,13 +1799,13 @@ def local_canonicalize_alloc(fgraph, node):
if not isinstance(op, Alloc): if not isinstance(op, Alloc):
return False return False
input = node.inputs[0] inp = node.inputs[0]
output = node.outputs[0] output = node.outputs[0]
# Check if dtype and broadcast remain the same. # Check if dtype and broadcast remain the same.
if input.type == output.type: if inp.type == output.type:
# We don't need to copy over any stack traces here # We don't need to copy over any stack traces here
return [input] return [inp]
# Allow local_merge_alloc to do its work first # Allow local_merge_alloc to do its work first
clients = fgraph.clients[output] clients = fgraph.clients[output]
...@@ -1817,20 +1817,20 @@ def local_canonicalize_alloc(fgraph, node): ...@@ -1817,20 +1817,20 @@ def local_canonicalize_alloc(fgraph, node):
output_shape = node.inputs[1:] output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0 num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - input.ndim): for i in range(len(output_shape) - inp.ndim):
if extract_constant(output_shape[i], only_process_constants=True) == 1: if extract_constant(output_shape[i], only_process_constants=True) == 1:
num_dims_with_size_1_added_to_left += 1 num_dims_with_size_1_added_to_left += 1
else: else:
break break
new_output_shape = output_shape[num_dims_with_size_1_added_to_left:] new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= input.ndim: if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim:
if ( if (
output.broadcastable[num_dims_with_size_1_added_to_left:] output.broadcastable[num_dims_with_size_1_added_to_left:]
== input.broadcastable == inp.broadcastable
): ):
inner = input inner = inp
else: else:
inner = op(*([input] + new_output_shape)) inner = op(*([inp] + new_output_shape))
dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list( dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list(
range(len(new_output_shape)) range(len(new_output_shape))
) )
...@@ -2292,14 +2292,14 @@ def local_rebroadcast_lift(fgraph, node): ...@@ -2292,14 +2292,14 @@ def local_rebroadcast_lift(fgraph, node):
if not isinstance(op, Rebroadcast): if not isinstance(op, Rebroadcast):
return False return False
input = node.inputs[0] inp = node.inputs[0]
inode = input.owner inode = inp.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
# It may happen that `input` has no client because this optimization # It may happen that `input` has no client because this optimization
# is called from `apply_rebroadcast_opt`, which in particular is used # is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function # by the `unbroadcast` function before we are in the actual function
# compilation phase. # compilation phase.
if len(fgraph.clients.get(input, ())) == 1: if len(fgraph.clients.get(inp, ())) == 1:
rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0]) rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0])
# Copy over stacktrace from previous output (after rebroadcasting) # Copy over stacktrace from previous output (after rebroadcasting)
# to new output, because an error in the new graph right after # to new output, because an error in the new graph right after
...@@ -2755,28 +2755,24 @@ def local_useless_reshape(fgraph, node): ...@@ -2755,28 +2755,24 @@ def local_useless_reshape(fgraph, node):
if not isinstance(op, Reshape): if not isinstance(op, Reshape):
return False return False
input = node.inputs[0] inp = node.inputs[0]
output = node.outputs[0] output = node.outputs[0]
output_shape = node.inputs[1] output_shape = node.inputs[1]
if input.ndim != output.ndim: if inp.ndim != output.ndim:
return False return False
# Simple case: both input and output have a single dimension. # Simple case: both input and output have a single dimension.
# This could hide errors if the user provides inconsistent shapes. # This could hide errors if the user provides inconsistent shapes.
if ( if inp.ndim == 1 and output.ndim == 1 and inp.broadcastable == output.broadcastable:
input.ndim == 1 return [inp]
and output.ndim == 1
and input.broadcastable == output.broadcastable
):
return [input]
# Second case: all the shapes match the input shape # Second case: all the shapes match the input shape
# Match Reshape(x, x.shape) # Match Reshape(x, x.shape)
if output_shape.owner and isinstance(output_shape.owner.op, Shape): if output_shape.owner and isinstance(output_shape.owner.op, Shape):
shape_input = output_shape.owner.inputs[0] shape_input = output_shape.owner.inputs[0]
if shape_input == input: if shape_input == inp:
return [input] return [inp]
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions # broadcastable and constant dimensions
...@@ -2786,15 +2782,15 @@ def local_useless_reshape(fgraph, node): ...@@ -2786,15 +2782,15 @@ def local_useless_reshape(fgraph, node):
shape_feature = getattr(fgraph, "shape_feature", None) shape_feature = getattr(fgraph, "shape_feature", None)
nb_m1 = 0 nb_m1 = 0
shape_match = [False] * input.ndim shape_match = [False] * inp.ndim
for dim in range(input.ndim): for dim in range(inp.ndim):
outshp_i = output_shape_is[dim] outshp_i = output_shape_is[dim]
# Match Shape_i{dim}(input) # Match Shape_i{dim}(input)
if ( if (
outshp_i.owner outshp_i.owner
and isinstance(outshp_i.owner.op, Shape_i) and isinstance(outshp_i.owner.op, Shape_i)
and outshp_i.owner.op.i == dim and outshp_i.owner.op.i == dim
and outshp_i.owner.inputs[0] == input and outshp_i.owner.inputs[0] == inp
): ):
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -2809,13 +2805,13 @@ def local_useless_reshape(fgraph, node): ...@@ -2809,13 +2805,13 @@ def local_useless_reshape(fgraph, node):
subtensor_inp = outshp_i.owner.inputs[0] subtensor_inp = outshp_i.owner.inputs[0]
if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
shape_input_i = subtensor_inp.owner.inputs[0] shape_input_i = subtensor_inp.owner.inputs[0]
if shape_input_i == input: if shape_input_i == inp:
shape_match[dim] = True shape_match[dim] = True
continue continue
# Match 1 if input.broadcastable[dim] is True # Match 1 if input.broadcastable[dim] is True
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if input.broadcastable[dim] and cst_outshp_i == 1: if inp.broadcastable[dim] and cst_outshp_i == 1:
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -2827,7 +2823,7 @@ def local_useless_reshape(fgraph, node): ...@@ -2827,7 +2823,7 @@ def local_useless_reshape(fgraph, node):
# Match shape_of[input][dim] or its constant equivalent # Match shape_of[input][dim] or its constant equivalent
if shape_feature: if shape_feature:
inpshp_i = shape_feature.get_shape(input, dim) inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or ( if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=1) extract_constant(inpshp_i, only_process_constants=1)
== extract_constant(outshp_i, only_process_constants=1) == extract_constant(outshp_i, only_process_constants=1)
...@@ -2836,7 +2832,7 @@ def local_useless_reshape(fgraph, node): ...@@ -2836,7 +2832,7 @@ def local_useless_reshape(fgraph, node):
continue continue
if all(shape_match) and nb_m1 <= 1: if all(shape_match) and nb_m1 <= 1:
return [input] return [inp]
# TODO later: if all the shapes except one match, we may want to # TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case. # consider it useless as well, like we do in the 1-dim case.
...@@ -2862,7 +2858,7 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -2862,7 +2858,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if not isinstance(op, Reshape): if not isinstance(op, Reshape):
return False return False
input = node.inputs[0] inp = node.inputs[0]
output = node.outputs[0] output = node.outputs[0]
output_shape = node.inputs[1] output_shape = node.inputs[1]
...@@ -2883,7 +2879,7 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -2883,7 +2879,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
new_output_shape.append(dim) new_output_shape.append(dim)
index = index + 1 index = index + 1
if index != output.ndim: if index != output.ndim:
inner = op.__class__(len(new_output_shape))(input, new_output_shape) inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
copy_stack_trace(output, inner) copy_stack_trace(output, inner)
new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
copy_stack_trace(output, new_node) copy_stack_trace(output, new_node)
...@@ -2937,8 +2933,8 @@ register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy") ...@@ -2937,8 +2933,8 @@ register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy")
@local_optimizer(None) @local_optimizer(None)
def constant_folding(fgraph, node): def constant_folding(fgraph, node):
for input in node.inputs: for inp in node.inputs:
if not isinstance(input, Constant): if not isinstance(inp, Constant):
return False return False
# condition: all inputs are constant # condition: all inputs are constant
if not node.op.do_constant_folding(fgraph, node): if not node.op.do_constant_folding(fgraph, node):
......
...@@ -627,7 +627,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -627,7 +627,7 @@ class AlgebraicCanonizer(LocalOptimizer):
def tracks(self): def tracks(self):
return [self.main, self.inverse, self.reciprocal] return [self.main, self.inverse, self.reciprocal]
def get_num_denum(self, input): def get_num_denum(self, inp):
r""" r"""
This extracts two lists, ``num`` and ``denum``, such that the input is: This extracts two lists, ``num`` and ``denum``, such that the input is:
``self.inverse(self.main(\*num), self.main(\*denum))``. It returns ``self.inverse(self.main(\*num), self.main(\*denum))``. It returns
...@@ -656,12 +656,12 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -656,12 +656,12 @@ class AlgebraicCanonizer(LocalOptimizer):
# argument. The leaf-Variables of the graph covered by the # argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type. # recursion may be of any Variable type.
if input.owner is None or input.owner.op not in [ if inp.owner is None or inp.owner.op not in [
self.main, self.main,
self.inverse, self.inverse,
self.reciprocal, self.reciprocal,
]: ]:
if input.owner and isinstance(input.owner.op, DimShuffle): if inp.owner and isinstance(inp.owner.op, DimShuffle):
# If input is a DimShuffle of some input which does # If input is a DimShuffle of some input which does
# something like this: # something like this:
...@@ -671,7 +671,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -671,7 +671,7 @@ class AlgebraicCanonizer(LocalOptimizer):
# with broadcastable 1s to the *left* # with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return # Then we will simply discard the DimShuffle and return
# the num/denum of its input # the num/denum of its input
dsn = input.owner # dimshuffle node dsn = inp.owner # dimshuffle node
dsop = dsn.op # dimshuffle op dsop = dsn.op # dimshuffle op
# the first input of the dimshuffle i.e. the ndarray to redim # the first input of the dimshuffle i.e. the ndarray to redim
...@@ -687,22 +687,22 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -687,22 +687,22 @@ class AlgebraicCanonizer(LocalOptimizer):
# different numbers of dimensions (hence why we can # different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it # discard its information - we know we can retrieve it
# later on). # later on).
compatible_order = ("x",) * (input.type.ndim - dsi0.type.ndim) + tuple( compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple(
range(dsi0.type.ndim) range(dsi0.type.ndim)
) )
if dsop.new_order == compatible_order: if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize, # If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input. # we return the num_denum of the dimshuffled input.
return self.get_num_denum(input.owner.inputs[0]) return self.get_num_denum(inp.owner.inputs[0])
else: else:
# This is when the input isn't produced by main, # This is when the input isn't produced by main,
# inverse or reciprocal. # inverse or reciprocal.
return [input], [] return [inp], []
else: else:
return [input], [] return [inp], []
num = [] num = []
denum = [] denum = []
parent = input.owner parent = inp.owner
# We get the (num, denum) pairs for each input # We get the (num, denum) pairs for each input
# pairs = [self.get_num_denum(input2) if input2.type.dtype == # pairs = [self.get_num_denum(input2) if input2.type.dtype ==
...@@ -1699,22 +1699,22 @@ def local_opt_alloc(fgraph, node): ...@@ -1699,22 +1699,22 @@ def local_opt_alloc(fgraph, node):
if isinstance(node.op, Sum) or isinstance(node.op, Prod): if isinstance(node.op, Sum) or isinstance(node.op, Prod):
(node_inps,) = node.inputs (node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc): if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
input = node_inps.owner.inputs[0] inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
try: try:
val = get_scalar_constant_value(input, only_process_constants=True) val = get_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
# check which type of op # check which type of op
size = mul(*shapes) size = mul(*shapes)
if input.dtype in ["float16", "float32"]: if inp.dtype in ["float16", "float32"]:
# shapes are ints and normally int64. # shapes are ints and normally int64.
# We don't want to have a float64 upcast # We don't want to have a float64 upcast
# We don't want to downcast to float16 # We don't want to downcast to float16
# as we fear it could lose too much precision # as we fear it could lose too much precision
# that will be amplified by the mul/pow below. # that will be amplified by the mul/pow below.
size = size.astype("float32") size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(input.ndim)): if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum): if isinstance(node.op, Sum):
val = val * size val = val * size
else: else:
...@@ -2010,15 +2010,15 @@ def local_mul_specialize(fgraph, node): ...@@ -2010,15 +2010,15 @@ def local_mul_specialize(fgraph, node):
new_inputs = [] new_inputs = []
nb_neg_node = 0 nb_neg_node = 0
nb_cst = 0 nb_cst = 0
for input in node.inputs: for inp in node.inputs:
# remove any neg arguments # remove any neg arguments
while input.owner and input.owner.op == neg: while inp.owner and inp.owner.op == neg:
has_neg ^= True has_neg ^= True
input = input.owner.inputs[0] inp = inp.owner.inputs[0]
nb_neg_node += 1 nb_neg_node += 1
# remove special case arguments of 1, -1 or 0 # remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input) y = local_mul_canonizer.get_constant(inp)
if y == 1.0: if y == 1.0:
nb_cst += 1 nb_cst += 1
elif y == -1.0: elif y == -1.0:
...@@ -2028,7 +2028,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2028,7 +2028,7 @@ def local_mul_specialize(fgraph, node):
# if we find any zero, we just return right away # if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], fgraph)] return [broadcast_like(0, node.outputs[0], fgraph)]
else: else:
new_inputs.append(input) new_inputs.append(inp)
if new_inputs != node.inputs: if new_inputs != node.inputs:
if new_inputs: if new_inputs:
...@@ -2072,14 +2072,14 @@ def local_add_specialize(fgraph, node): ...@@ -2072,14 +2072,14 @@ def local_add_specialize(fgraph, node):
# to put in un-necessary fills. # to put in un-necessary fills.
if node.op == add: if node.op == add:
new_inputs = [] new_inputs = []
for input in node.inputs: for inp in node.inputs:
try: try:
y = get_scalar_constant_value(input) y = get_scalar_constant_value(inp)
except NotScalarConstantError: except NotScalarConstantError:
y = input y = inp
if np.all(y == 0.0): if np.all(y == 0.0):
continue continue
new_inputs.append(input) new_inputs.append(inp)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论