提交 5c87d741 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix type check in local_pow_specialize

上级 7367e8d0
......@@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
rval = [reciprocal(sqr(xsym))]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
assert rval[0].type.is_super(node.outputs[0].type), (
rval[0].type,
node.outputs[0].type,
)
return rval
else:
return False
......
......@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import (
perform_sigm_times_exp,
simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
from pytensor.tensor.type import (
TensorType,
cmatrix,
......@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX)
f = function([v], v**twos, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Depending on the mode the SpecifyShape is lifted or not
if topo[0].op == sqr:
assert isinstance(topo[1].op, SpecifyShape)
else:
assert isinstance(topo[0].op, SpecifyShape)
assert topo[1].op == sqr
utt.assert_allclose(f(val), val**twos)
def test_local_pow_to_nested_squaring():
mode = config.mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论