提交 326cb2e3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fail graciously in local_pow_to_nested_squaring when static type shape is updated

上级 3169197c
......@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
"""
if node.op == at_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
# That happen in the `test_log_erfc` test.
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
# the abs(y) <= 512 fail!
# taking the value outside ndarray solve the problem.
# it could be that in that case, numpy make the comparison
# into the wrong type(do in int8 that overflow.)
if isinstance(y, np.ndarray):
assert y.size == 1
try:
y = y[0]
except IndexError:
pass
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym]
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y)
for i in range(int(np.log2(y_to_do))):
pow2.append(sqr(pow2[i]))
pow2_scal.append(aes.sqr(pow2_scal[i]))
rval1 = None
rval1_scal = None
while y_to_do > 0:
log_to_do = int(np.log2(y_to_do))
if rval1:
rval1 *= pow2[log_to_do]
rval1_scal *= pow2_scal[log_to_do]
else:
rval1 = pow2[log_to_do]
rval1_scal = pow2_scal[log_to_do]
y_to_do -= 2**log_to_do
if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(
aes.Composite([pow2_scal[0]], [rval1_scal])
).make_node(xsym)
if y < 0:
rval = [reciprocal(rval1)]
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
# That happen in the `test_log_erfc` test.
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
# the abs(y) <= 512 fail!
# taking the value outside ndarray solve the problem.
# it could be that in that case, numpy make the comparison
# into the wrong type(do in int8 that overflow.)
if isinstance(y, np.ndarray):
assert y.size == 1
try:
y = y[0]
except IndexError:
pass
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym]
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y)
for i in range(int(np.log2(y_to_do))):
pow2.append(sqr(pow2[i]))
pow2_scal.append(aes.sqr(pow2_scal[i]))
rval1 = None
rval1_scal = None
while y_to_do > 0:
log_to_do = int(np.log2(y_to_do))
if rval1:
rval1 *= pow2[log_to_do]
rval1_scal *= pow2_scal[log_to_do]
else:
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
return rval
rval1 = pow2[log_to_do]
rval1_scal = pow2_scal[log_to_do]
y_to_do -= 2**log_to_do
if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
xsym
)
if y < 0:
rval = [reciprocal(rval1)]
else:
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
return None
return rval
@register_specialize
......
......@@ -29,8 +29,9 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.scalar import Pow
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, second, switch
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -69,7 +70,7 @@ from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min
from pytensor.tensor.math import minimum, mul, neg, neq
from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import (
prod,
rad2deg,
......@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
utt.assert_allclose(f(val_no0), val_no0 ** (-16))
def test_local_pow_to_nested_squaring_fails_gracefully():
# Reported in #456
x = vector("x", shape=(1,))
# Create an Apply that does not have precise output shape
node = Apply(
op=pt_pow,
inputs=[x, constant([2.0])],
outputs=[tensor(shape=(None,))],
)
y = node.default_output()
fn = function([x], y)
# Check rewrite is not applied (this could change in the future)
assert any(
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
for node in fn.maker.fgraph.apply_nodes
)
np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
class TestFuncInverse:
def setup_method(self):
mode = get_default_mode()
......@@ -2449,7 +2473,7 @@ class TestLocalMergeSwitchSameCond:
le,
eq,
neq,
at_pow,
pt_pow,
):
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert debugprint(g, file="str").count("Switch") == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论