提交 5c87d741 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix type check in local_pow_specialize

上级 7367e8d0
...@@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node): ...@@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
rval = [reciprocal(sqr(xsym))] rval = [reciprocal(sqr(xsym))]
if rval: if rval:
rval[0] = cast(rval[0], odtype) rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs) assert rval[0].type.is_super(node.outputs[0].type), (
rval[0].type,
node.outputs[0].type,
)
return rval return rval
else: else:
return False return False
......
...@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import ( ...@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import (
perform_sigm_times_exp, perform_sigm_times_exp,
simplify_mul, simplify_mul,
) )
from pytensor.tensor.shape import Reshape, Shape_i from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
cmatrix, cmatrix,
...@@ -1671,6 +1671,18 @@ def test_local_pow_specialize(): ...@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal) assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5)) utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX)
f = function([v], v**twos, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Depending on the mode the SpecifyShape is lifted or not
if topo[0].op == sqr:
assert isinstance(topo[1].op, SpecifyShape)
else:
assert isinstance(topo[0].op, SpecifyShape)
assert topo[1].op == sqr
utt.assert_allclose(f(val), val**twos)
def test_local_pow_to_nested_squaring(): def test_local_pow_to_nested_squaring():
mode = config.mode mode = config.mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论