提交 04ce1c6c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove internal `get_constant helper`

Fixes bug in `local_add_neg_to_sub` reported in https://github.com/pymc-devs/pytensor/issues/584
上级 55f3cd0c
...@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): ...@@ -126,24 +126,6 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
return consts, origconsts, nonconsts return consts, origconsts, nonconsts
def get_constant(v):
"""
Returns
-------
object
A numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, TensorConstant):
return v.unique_value
elif isinstance(v, Variable):
return None
else:
return v
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@node_rewriter([Dot]) @node_rewriter([Dot])
...@@ -994,8 +976,8 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -994,8 +976,8 @@ class AlgebraicCanonizer(NodeRewriter):
""" """
Find all constants and put them together into a single constant. Find all constants and put them together into a single constant.
Finds all constants in orig_num and orig_denum (using Finds all constants in orig_num and orig_denum
get_constant) and puts them together into a single and puts them together into a single
constant. The constant is inserted as the first element of the constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is numerator. If the constant is the neutral element, it is
removed from the numerator. removed from the numerator.
...@@ -1016,17 +998,15 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1016,17 +998,15 @@ class AlgebraicCanonizer(NodeRewriter):
numct, denumct = [], [] numct, denumct = [], []
for v in orig_num: for v in orig_num:
ct = get_constant(v) if isinstance(v, TensorConstant) and v.unique_value is not None:
if ct is not None:
# We found a constant in the numerator! # We found a constant in the numerator!
# We add it to numct # We add it to numct
numct.append(ct) numct.append(v.unique_value)
else: else:
num.append(v) num.append(v)
for v in orig_denum: for v in orig_denum:
ct = get_constant(v) if isinstance(v, TensorConstant) and v.unique_value is not None:
if ct is not None: denumct.append(v.unique_value)
denumct.append(ct)
else: else:
denum.append(v) denum.append(v)
...@@ -1050,10 +1030,15 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1050,10 +1030,15 @@ class AlgebraicCanonizer(NodeRewriter):
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
# In that case we should only have one constant in `ct`. # In that case we should only have one constant in `ct`.
assert len(ct) == 1 [var_ct] = ct
first_num_ct = get_constant(orig_num[0]) first_num_var = orig_num[0]
if first_num_ct is not None and ct[0].type.values_eq( first_num_ct = (
ct[0].data, first_num_ct first_num_var.unique_value
if isinstance(first_num_var, TensorConstant)
else None
)
if first_num_ct is not None and var_ct.type.values_eq(
var_ct.data, first_num_ct
): ):
# This is an important trick :( if it so happens that: # This is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on # * there's exactly one constant on the numerator and none on
...@@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node): ...@@ -1840,9 +1825,12 @@ def local_add_neg_to_sub(fgraph, node):
return [new_out] return [new_out]
# Check if it is a negative constant # Check if it is a negative constant
const = get_constant(second) if (
if const is not None and const < 0: isinstance(second, TensorConstant)
new_out = sub(first, np.abs(const)) and second.unique_value is not None
and second.unique_value < 0
):
new_out = sub(first, np.abs(second.data))
return [new_out] return [new_out]
...@@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node): ...@@ -1871,7 +1859,12 @@ def local_mul_zero(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([true_div]) @node_rewriter([true_div])
def local_div_to_reciprocal(fgraph, node): def local_div_to_reciprocal(fgraph, node):
if np.all(get_constant(node.inputs[0]) == 1.0): if (
get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
== 1.0
):
out = node.outputs[0] out = node.outputs[0]
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting # The ones could have forced upcasting
...@@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node): ...@@ -1892,7 +1885,9 @@ def local_reciprocal_canon(fgraph, node):
@register_canonicalize @register_canonicalize
@node_rewriter([pt_pow]) @node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node): def local_pow_canonicalize(fgraph, node):
cst = get_constant(node.inputs[1]) cst = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
if cst == 0: if cst == 0:
return [alloc_like(1, node.outputs[0], fgraph)] return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1: if cst == 1:
...@@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node): ...@@ -1923,7 +1918,12 @@ def local_intdiv_by_one(fgraph, node):
@node_rewriter([int_div, true_div]) @node_rewriter([int_div, true_div])
def local_zero_div(fgraph, node): def local_zero_div(fgraph, node):
"""0 / x -> 0""" """0 / x -> 0"""
if get_constant(node.inputs[0]) == 0: if (
get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
== 0
):
ret = alloc_like(0, node.outputs[0], fgraph) ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret] return [ret]
...@@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node): ...@@ -1936,8 +1936,12 @@ def local_pow_specialize(fgraph, node):
odtype = node.outputs[0].dtype odtype = node.outputs[0].dtype
xsym = node.inputs[0] xsym = node.inputs[0]
ysym = node.inputs[1] ysym = node.inputs[1]
y = get_constant(ysym) try:
if (y is not None) and not broadcasted_by(xsym, ysym): y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
except NotScalarConstantError:
return
if not broadcasted_by(xsym, ysym):
rval = None rval = None
if np.all(y == 2): if np.all(y == 2):
...@@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -1971,10 +1975,14 @@ def local_pow_to_nested_squaring(fgraph, node):
""" """
# the idea here is that we have pow(x, y) # the idea here is that we have pow(x, y)
xsym, ysym = node.inputs
try:
y = get_underlying_scalar_constant_value(ysym, only_process_constants=True)
except NotScalarConstantError:
return
odtype = node.outputs[0].dtype odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
# the next line is needed to fix a strange case that I don't # the next line is needed to fix a strange case that I don't
# know how to make a separate test. # know how to make a separate test.
...@@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -1990,7 +1998,7 @@ def local_pow_to_nested_squaring(fgraph, node):
y = y[0] y = y[0]
except IndexError: except IndexError:
pass pass
if (y is not None) and not broadcasted_by(xsym, ysym): if not broadcasted_by(xsym, ysym):
rval = None rval = None
# 512 is too small for the cpu and too big for some gpu! # 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512: if abs(y) == int(abs(y)) and abs(y) <= 512:
...@@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node): ...@@ -2057,7 +2065,9 @@ def local_mul_specialize(fgraph, node):
nb_neg_node += 1 nb_neg_node += 1
# remove special case arguments of 1, -1 or 0 # remove special case arguments of 1, -1 or 0
y = get_constant(inp) y = get_underlying_scalar_constant_value(
inp, only_process_constants=True, raise_not_constant=False
)
if y == 1.0: if y == 1.0:
nb_cst += 1 nb_cst += 1
elif y == -1.0: elif y == -1.0:
......
...@@ -4440,16 +4440,18 @@ def test_local_add_neg_to_sub(first_negative): ...@@ -4440,16 +4440,18 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp) assert np.allclose(f(x_test, y_test), exp)
def test_local_add_neg_to_sub_const(): @pytest.mark.parametrize("const_left", (True, False))
def test_local_add_neg_to_sub_const(const_left):
x = vector("x") x = vector("x")
const = 5.0 const = np.full((3, 2), 5.0)
out = -const + x if const_left else x + (-const)
f = function([x], x + (-const), mode=Mode("py")) f = function([x], out, mode=Mode("py"))
nodes = [ nodes = [
node.op node.op
for node in f.maker.fgraph.toposort() for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle) if not isinstance(node.op, DimShuffle | Alloc)
] ]
assert nodes == [pt.sub] assert nodes == [pt.sub]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论