提交 35f52cff authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract get_constant from AlgebraicCanonizer

上级 b25fb1be
......@@ -101,6 +101,31 @@ _logger = logging.getLogger("aesara.tensor.math_opt")
_logger.addFilter(NoDuplicateOptWarningFilter())
def get_constant(v):
"""
Returns
-------
object
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
data = v.tag.unique_value
else:
data = v.data
if data.ndim == 0:
return data
else:
return None
elif isinstance(v, Variable):
return None
else:
return v
def fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = fill(i, new_out)
......@@ -777,31 +802,6 @@ class AlgebraicCanonizer(LocalOptimizer):
self.merge_num_denum(num, []), self.merge_num_denum(denum, [])
)
@staticmethod
def get_constant(v):
"""
Returns
-------
object
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
data = v.tag.unique_value
else:
data = v.data
if data.ndim == 0:
return data
else:
return None
elif isinstance(v, Variable):
return None
else:
return v
def simplify(self, num, denum, out_type):
"""
Shorthand for:
......@@ -879,7 +879,7 @@ class AlgebraicCanonizer(LocalOptimizer):
numct, denumct = [], []
for v in orig_num:
ct = self.get_constant(v)
ct = get_constant(v)
if ct is not None:
# We found a constant in the numerator!
# We add it to numct
......@@ -887,7 +887,7 @@ class AlgebraicCanonizer(LocalOptimizer):
else:
num.append(v)
for v in orig_denum:
ct = self.get_constant(v)
ct = get_constant(v)
if ct is not None:
denumct.append(ct)
else:
......@@ -914,7 +914,7 @@ class AlgebraicCanonizer(LocalOptimizer):
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
# In that case we should only have one constant in `ct`.
assert len(ct) == 1
first_num_ct = self.get_constant(orig_num[0])
first_num_ct = get_constant(orig_num[0])
if first_num_ct is not None and ct[0].type.values_eq(
ct[0].data, first_num_ct
):
......@@ -1801,9 +1801,7 @@ def local_mul_zero(fgraph, node):
@register_specialize
@local_optimizer([true_div])
def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0
):
if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
......@@ -1830,7 +1828,7 @@ def local_reciprocal_canon(fgraph, node):
@local_optimizer([aet_pow])
def local_pow_canonicalize(fgraph, node):
if node.op == aet_pow:
cst = local_mul_canonizer.get_constant(node.inputs[1])
cst = get_constant(node.inputs[1])
if cst == 0:
return [broadcast_like(1, node.outputs[0], fgraph)]
if cst == 1:
......@@ -1874,7 +1872,7 @@ def local_zero_div(fgraph, node):
if isinstance(node.op, Elemwise) and isinstance(
node.op.scalar_op, (aes.IntDiv, aes.TrueDiv)
):
if local_mul_canonizer.get_constant(node.inputs[0]) == 0:
if get_constant(node.inputs[0]) == 0:
ret = broadcast_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
......@@ -1890,7 +1888,7 @@ def local_pow_specialize(fgraph, node):
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym)
y = get_constant(ysym)
if (y is not None) and encompasses_broadcastable(
xsym.type.broadcastable, ysym.type.broadcastable
):
......@@ -1929,7 +1927,7 @@ def local_pow_specialize_device(fgraph, node):
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym)
y = get_constant(ysym)
# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
......@@ -2018,7 +2016,7 @@ def local_mul_specialize(fgraph, node):
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(inp)
y = get_constant(inp)
if y == 1.0:
nb_cst += 1
elif y == -1.0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论