提交 326cb2e3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fail gracefully in local_pow_to_nested_squaring when static type shape is updated

上级 3169197c
...@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
Note: This sounds like the kind of thing any half-decent compiler can do by itself? Note: This sounds like the kind of thing any half-decent compiler can do by itself?
""" """
if node.op == at_pow: # the idea here is that we have pow(x, y)
# the idea here is that we have pow(x, y) odtype = node.outputs[0].dtype
odtype = node.outputs[0].dtype xsym = node.inputs[0]
xsym = node.inputs[0] ysym = node.inputs[1]
ysym = node.inputs[1] y = get_constant(ysym)
y = get_constant(ysym)
# the next line is needed to fix a strange case that I don't
# the next line is needed to fix a strange case that I don't # know how to make a separate test.
# know how to make a separate test. # That happen in the `test_log_erfc` test.
# That happen in the `test_log_erfc` test. # y is a ndarray with dtype int8 and value 2,4 or 6. This make
# y is a ndarray with dtype int8 and value 2,4 or 6. This make # the abs(y) <= 512 fail!
# the abs(y) <= 512 fail! # taking the value outside ndarray solve the problem.
# taking the value outside ndarray solve the problem. # it could be that in that case, numpy make the comparison
# it could be that in that case, numpy make the comparison # into the wrong type(do in int8 that overflow.)
# into the wrong type(do in int8 that overflow.) if isinstance(y, np.ndarray):
if isinstance(y, np.ndarray): assert y.size == 1
assert y.size == 1 try:
try: y = y[0]
y = y[0] except IndexError:
except IndexError: pass
pass if (y is not None) and not broadcasted_by(xsym, ysym):
if (y is not None) and not broadcasted_by(xsym, ysym): rval = None
rval = None # 512 is too small for the cpu and too big for some gpu!
# 512 is too small for the cpu and too big for some gpu! if abs(y) == int(abs(y)) and abs(y) <= 512:
if abs(y) == int(abs(y)) and abs(y) <= 512: pow2 = [xsym]
pow2 = [xsym] pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
pow2_scal = [aes.get_scalar_type(xsym.dtype)()] y_to_do = abs(y)
y_to_do = abs(y) for i in range(int(np.log2(y_to_do))):
for i in range(int(np.log2(y_to_do))): pow2.append(sqr(pow2[i]))
pow2.append(sqr(pow2[i])) pow2_scal.append(aes.sqr(pow2_scal[i]))
pow2_scal.append(aes.sqr(pow2_scal[i])) rval1 = None
rval1 = None rval1_scal = None
rval1_scal = None while y_to_do > 0:
while y_to_do > 0: log_to_do = int(np.log2(y_to_do))
log_to_do = int(np.log2(y_to_do)) if rval1:
if rval1: rval1 *= pow2[log_to_do]
rval1 *= pow2[log_to_do] rval1_scal *= pow2_scal[log_to_do]
rval1_scal *= pow2_scal[log_to_do]
else:
rval1 = pow2[log_to_do]
rval1_scal = pow2_scal[log_to_do]
y_to_do -= 2**log_to_do
if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(
aes.Composite([pow2_scal[0]], [rval1_scal])
).make_node(xsym)
if y < 0:
rval = [reciprocal(rval1)]
else: else:
rval = [rval1] rval1 = pow2[log_to_do]
if rval: rval1_scal = pow2_scal[log_to_do]
rval[0] = cast(rval[0], odtype) y_to_do -= 2**log_to_do
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
return rval if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
xsym
)
if y < 0:
rval = [reciprocal(rval1)]
else:
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
return None
return rval
@register_specialize @register_specialize
......
...@@ -29,8 +29,9 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -29,8 +29,9 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.scalar import Pow
from pytensor.tensor import inplace from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, second, switch from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -69,7 +70,7 @@ from pytensor.tensor.math import max as at_max ...@@ -69,7 +70,7 @@ from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min from pytensor.tensor.math import min as at_min
from pytensor.tensor.math import minimum, mul, neg, neq from pytensor.tensor.math import minimum, mul, neg, neq
from pytensor.tensor.math import pow as at_pow from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import ( from pytensor.tensor.math import (
prod, prod,
rad2deg, rad2deg,
...@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring(): ...@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
utt.assert_allclose(f(val_no0), val_no0 ** (-16)) utt.assert_allclose(f(val_no0), val_no0 ** (-16))
def test_local_pow_to_nested_squaring_fails_gracefully():
    """Regression test for the failure reported in issue #456.

    A ``pow`` Apply node is built by hand so that its output type carries a
    less precise static shape (``(None,)``) than its input ``x`` (``(1,)``).
    The compiled function must still contain an ``Elemwise`` wrapping a scalar
    ``Pow`` (i.e. the pow node was not replaced) and must still return the
    correct numerical result.
    """
    x = vector("x", shape=(1,))

    # Assemble the node manually: going through the Apply constructor lets us
    # attach an output variable whose static shape is deliberately imprecise.
    exponent = constant([2.0])
    imprecise_out = tensor(shape=(None,))
    pow_apply = Apply(op=pt_pow, inputs=[x, exponent], outputs=[imprecise_out])

    fn = function([x], pow_apply.default_output())

    # The rewrite is expected to be skipped here, so a Pow Elemwise must
    # survive compilation (this expectation could change in the future).
    pow_survived = False
    for apply_node in fn.maker.fgraph.apply_nodes:
        candidate_op = apply_node.op
        if isinstance(candidate_op, Elemwise) and isinstance(
            candidate_op.scalar_op, Pow
        ):
            pow_survived = True
            break
    assert pow_survived

    np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
class TestFuncInverse: class TestFuncInverse:
def setup_method(self): def setup_method(self):
mode = get_default_mode() mode = get_default_mode()
...@@ -2449,7 +2473,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -2449,7 +2473,7 @@ class TestLocalMergeSwitchSameCond:
le, le,
eq, eq,
neq, neq,
at_pow, pt_pow,
): ):
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert debugprint(g, file="str").count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论