提交 326cb2e3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fail graciously in local_pow_to_nested_squaring when static type shape is updated

上级 3169197c
...@@ -2081,7 +2081,6 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2081,7 +2081,6 @@ def local_pow_to_nested_squaring(fgraph, node):
Note: This sounds like the kind of thing any half-decent compiler can do by itself? Note: This sounds like the kind of thing any half-decent compiler can do by itself?
""" """
if node.op == at_pow:
# the idea here is that we have pow(x, y) # the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype odtype = node.outputs[0].dtype
xsym = node.inputs[0] xsym = node.inputs[0]
...@@ -2127,16 +2126,19 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2127,16 +2126,19 @@ def local_pow_to_nested_squaring(fgraph, node):
if abs(y) > 2: if abs(y) > 2:
# We fuse all the pow together here to make # We fuse all the pow together here to make
# compilation faster # compilation faster
rval1 = Elemwise( rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
aes.Composite([pow2_scal[0]], [rval1_scal]) xsym
).make_node(xsym) )
if y < 0: if y < 0:
rval = [reciprocal(rval1)] rval = [reciprocal(rval1)]
else: else:
rval = [rval1] rval = [rval1]
if rval: if rval:
rval[0] = cast(rval[0], odtype) rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs) # TODO: We can add a specify_broadcastable and/or unbroadcast to make the
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
return None
return rval return rval
......
...@@ -29,8 +29,9 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -29,8 +29,9 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.scalar import Pow
from pytensor.tensor import inplace from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, second, switch from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -69,7 +70,7 @@ from pytensor.tensor.math import max as at_max ...@@ -69,7 +70,7 @@ from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min from pytensor.tensor.math import min as at_min
from pytensor.tensor.math import minimum, mul, neg, neq from pytensor.tensor.math import minimum, mul, neg, neq
from pytensor.tensor.math import pow as at_pow from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import ( from pytensor.tensor.math import (
prod, prod,
rad2deg, rad2deg,
...@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring(): ...@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
utt.assert_allclose(f(val_no0), val_no0 ** (-16)) utt.assert_allclose(f(val_no0), val_no0 ** (-16))
def test_local_pow_to_nested_squaring_fails_gracefully():
# Reported in #456
x = vector("x", shape=(1,))
# Create an Apply that does not have precise output shape
node = Apply(
op=pt_pow,
inputs=[x, constant([2.0])],
outputs=[tensor(shape=(None,))],
)
y = node.default_output()
fn = function([x], y)
# Check rewrite is not applied (this could change in the future)
assert any(
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
for node in fn.maker.fgraph.apply_nodes
)
np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
class TestFuncInverse: class TestFuncInverse:
def setup_method(self): def setup_method(self):
mode = get_default_mode() mode = get_default_mode()
...@@ -2449,7 +2473,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -2449,7 +2473,7 @@ class TestLocalMergeSwitchSameCond:
le, le,
eq, eq,
neq, neq,
at_pow, pt_pow,
): ):
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert debugprint(g, file="str").count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论