提交 b25fb1be authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent unnecessary shadowing of builtin input

上级 9078756f
差异被折叠。
......@@ -627,7 +627,7 @@ class AlgebraicCanonizer(LocalOptimizer):
def tracks(self):
return [self.main, self.inverse, self.reciprocal]
def get_num_denum(self, input):
def get_num_denum(self, inp):
r"""
This extract two lists, ``num`` and ``denum``, such that the input is:
``self.inverse(self.main(\*num), self.main(\*denum))``. It returns
......@@ -656,12 +656,12 @@ class AlgebraicCanonizer(LocalOptimizer):
# argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type.
if input.owner is None or input.owner.op not in [
if inp.owner is None or inp.owner.op not in [
self.main,
self.inverse,
self.reciprocal,
]:
if input.owner and isinstance(input.owner.op, DimShuffle):
if inp.owner and isinstance(inp.owner.op, DimShuffle):
# If input is a DimShuffle of some input which does
# something like this:
......@@ -671,7 +671,7 @@ class AlgebraicCanonizer(LocalOptimizer):
# with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return
# the num/denum of its input
dsn = input.owner # dimshuffle node
dsn = inp.owner # dimshuffle node
dsop = dsn.op # dimshuffle op
# the first input of the dimshuffle i.e. the ndarray to redim
......@@ -687,22 +687,22 @@ class AlgebraicCanonizer(LocalOptimizer):
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# later on).
compatible_order = ("x",) * (input.type.ndim - dsi0.type.ndim) + tuple(
compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple(
range(dsi0.type.ndim)
)
if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
return self.get_num_denum(input.owner.inputs[0])
return self.get_num_denum(inp.owner.inputs[0])
else:
# This is when the input isn't produced by main,
# inverse or reciprocal.
return [input], []
return [inp], []
else:
return [input], []
return [inp], []
num = []
denum = []
parent = input.owner
parent = inp.owner
# We get the (num, denum) pairs for each input
# pairs = [self.get_num_denum(input2) if input2.type.dtype ==
......@@ -1699,22 +1699,22 @@ def local_opt_alloc(fgraph, node):
if isinstance(node.op, Sum) or isinstance(node.op, Prod):
(node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
input = node_inps.owner.inputs[0]
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_scalar_constant_value(input, only_process_constants=True)
val = get_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
size = mul(*shapes)
if input.dtype in ["float16", "float32"]:
if inp.dtype in ["float16", "float32"]:
# shapes are ints and normally int64.
# We don't want to have a float64 upcast
# We don't want to downcast to float16
# as we fear it could loose too much precision
# that will be amplified by the mul/pow below.
size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(input.ndim)):
if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum):
val = val * size
else:
......@@ -2010,15 +2010,15 @@ def local_mul_specialize(fgraph, node):
new_inputs = []
nb_neg_node = 0
nb_cst = 0
for input in node.inputs:
for inp in node.inputs:
# remove any neg arguments
while input.owner and input.owner.op == neg:
while inp.owner and inp.owner.op == neg:
has_neg ^= True
input = input.owner.inputs[0]
inp = inp.owner.inputs[0]
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input)
y = local_mul_canonizer.get_constant(inp)
if y == 1.0:
nb_cst += 1
elif y == -1.0:
......@@ -2028,7 +2028,7 @@ def local_mul_specialize(fgraph, node):
# if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], fgraph)]
else:
new_inputs.append(input)
new_inputs.append(inp)
if new_inputs != node.inputs:
if new_inputs:
......@@ -2072,14 +2072,14 @@ def local_add_specialize(fgraph, node):
# to put in un-necessary fills.
if node.op == add:
new_inputs = []
for input in node.inputs:
for inp in node.inputs:
try:
y = get_scalar_constant_value(input)
y = get_scalar_constant_value(inp)
except NotScalarConstantError:
y = input
y = inp
if np.all(y == 0.0):
continue
new_inputs.append(input)
new_inputs.append(inp)
if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论