提交 b8e26cd4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make local_pow_to_nested_squaring more permissive

上级 e48ff560
......@@ -2120,10 +2120,6 @@ def local_pow_to_nested_squaring(fgraph, node):
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
return None
return rval
......
......@@ -29,7 +29,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
......@@ -1757,7 +1757,7 @@ def test_local_pow_to_nested_squaring():
utt.assert_allclose(f(val_no0), val_no0 ** (-16))
def test_local_pow_to_nested_squaring_fails_gracefully():
def test_local_pow_to_nested_squaring_works_with_static_type():
# Reported in #456
x = vector("x", shape=(1,))
......@@ -1771,12 +1771,6 @@ def test_local_pow_to_nested_squaring_fails_gracefully():
fn = function([x], y)
# Check rewrite is not applied (this could change in the future)
assert any(
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
for node in fn.maker.fgraph.apply_nodes
)
np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论