提交 b25fb1be authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent unnecessary shadowing of builtin input

上级 9078756f
差异被折叠。
...@@ -627,7 +627,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -627,7 +627,7 @@ class AlgebraicCanonizer(LocalOptimizer):
def tracks(self): def tracks(self):
return [self.main, self.inverse, self.reciprocal] return [self.main, self.inverse, self.reciprocal]
def get_num_denum(self, input): def get_num_denum(self, inp):
r""" r"""
This extract two lists, ``num`` and ``denum``, such that the input is: This extract two lists, ``num`` and ``denum``, such that the input is:
``self.inverse(self.main(\*num), self.main(\*denum))``. It returns ``self.inverse(self.main(\*num), self.main(\*denum))``. It returns
...@@ -656,12 +656,12 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -656,12 +656,12 @@ class AlgebraicCanonizer(LocalOptimizer):
# argument. The leaf-Variables of the graph covered by the # argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type. # recursion may be of any Variable type.
if input.owner is None or input.owner.op not in [ if inp.owner is None or inp.owner.op not in [
self.main, self.main,
self.inverse, self.inverse,
self.reciprocal, self.reciprocal,
]: ]:
if input.owner and isinstance(input.owner.op, DimShuffle): if inp.owner and isinstance(inp.owner.op, DimShuffle):
# If input is a DimShuffle of some input which does # If input is a DimShuffle of some input which does
# something like this: # something like this:
...@@ -671,7 +671,7 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -671,7 +671,7 @@ class AlgebraicCanonizer(LocalOptimizer):
# with broadcastable 1s to the *left* # with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return # Then we will simply discard the DimShuffle and return
# the num/denum of its input # the num/denum of its input
dsn = input.owner # dimshuffle node dsn = inp.owner # dimshuffle node
dsop = dsn.op # dimshuffle op dsop = dsn.op # dimshuffle op
# the first input of the dimshuffle i.e. the ndarray to redim # the first input of the dimshuffle i.e. the ndarray to redim
...@@ -687,22 +687,22 @@ class AlgebraicCanonizer(LocalOptimizer): ...@@ -687,22 +687,22 @@ class AlgebraicCanonizer(LocalOptimizer):
# different numbers of dimensions (hence why we can # different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it # discard its information - we know we can retrieve it
# later on). # later on).
compatible_order = ("x",) * (input.type.ndim - dsi0.type.ndim) + tuple( compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple(
range(dsi0.type.ndim) range(dsi0.type.ndim)
) )
if dsop.new_order == compatible_order: if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize, # If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input. # we return the num_denum of the dimshuffled input.
return self.get_num_denum(input.owner.inputs[0]) return self.get_num_denum(inp.owner.inputs[0])
else: else:
# This is when the input isn't produced by main, # This is when the input isn't produced by main,
# inverse or reciprocal. # inverse or reciprocal.
return [input], [] return [inp], []
else: else:
return [input], [] return [inp], []
num = [] num = []
denum = [] denum = []
parent = input.owner parent = inp.owner
# We get the (num, denum) pairs for each input # We get the (num, denum) pairs for each input
# pairs = [self.get_num_denum(input2) if input2.type.dtype == # pairs = [self.get_num_denum(input2) if input2.type.dtype ==
...@@ -1699,22 +1699,22 @@ def local_opt_alloc(fgraph, node): ...@@ -1699,22 +1699,22 @@ def local_opt_alloc(fgraph, node):
if isinstance(node.op, Sum) or isinstance(node.op, Prod): if isinstance(node.op, Sum) or isinstance(node.op, Prod):
(node_inps,) = node.inputs (node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc): if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
input = node_inps.owner.inputs[0] inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
try: try:
val = get_scalar_constant_value(input, only_process_constants=True) val = get_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
# check which type of op # check which type of op
size = mul(*shapes) size = mul(*shapes)
if input.dtype in ["float16", "float32"]: if inp.dtype in ["float16", "float32"]:
# shapes are ints and normally int64. # shapes are ints and normally int64.
# We don't want to have a float64 upcast # We don't want to have a float64 upcast
# We don't want to downcast to float16 # We don't want to downcast to float16
# as we fear it could loose too much precision # as we fear it could loose too much precision
# that will be amplified by the mul/pow below. # that will be amplified by the mul/pow below.
size = size.astype("float32") size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(input.ndim)): if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum): if isinstance(node.op, Sum):
val = val * size val = val * size
else: else:
...@@ -2010,15 +2010,15 @@ def local_mul_specialize(fgraph, node): ...@@ -2010,15 +2010,15 @@ def local_mul_specialize(fgraph, node):
new_inputs = [] new_inputs = []
nb_neg_node = 0 nb_neg_node = 0
nb_cst = 0 nb_cst = 0
for input in node.inputs: for inp in node.inputs:
# remove any neg arguments # remove any neg arguments
while input.owner and input.owner.op == neg: while inp.owner and inp.owner.op == neg:
has_neg ^= True has_neg ^= True
input = input.owner.inputs[0] inp = inp.owner.inputs[0]
nb_neg_node += 1 nb_neg_node += 1
# remove special case arguments of 1, -1 or 0 # remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input) y = local_mul_canonizer.get_constant(inp)
if y == 1.0: if y == 1.0:
nb_cst += 1 nb_cst += 1
elif y == -1.0: elif y == -1.0:
...@@ -2028,7 +2028,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2028,7 +2028,7 @@ def local_mul_specialize(fgraph, node):
# if we find any zero, we just return right away # if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], fgraph)] return [broadcast_like(0, node.outputs[0], fgraph)]
else: else:
new_inputs.append(input) new_inputs.append(inp)
if new_inputs != node.inputs: if new_inputs != node.inputs:
if new_inputs: if new_inputs:
...@@ -2072,14 +2072,14 @@ def local_add_specialize(fgraph, node): ...@@ -2072,14 +2072,14 @@ def local_add_specialize(fgraph, node):
# to put in un-necessary fills. # to put in un-necessary fills.
if node.op == add: if node.op == add:
new_inputs = [] new_inputs = []
for input in node.inputs: for inp in node.inputs:
try: try:
y = get_scalar_constant_value(input) y = get_scalar_constant_value(inp)
except NotScalarConstantError: except NotScalarConstantError:
y = input y = inp
if np.all(y == 0.0): if np.all(y == 0.0):
continue continue
new_inputs.append(input) new_inputs.append(inp)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论