Commit c513419c authored by Ricardo Vieira, committed by Ricardo Vieira

Numba Blockwise: Fix OpFromGraph as core_op

Parent 01ac0c7c
@@ -4,6 +4,7 @@ import warnings
 from collections.abc import Callable, Sequence
 from copy import copy
 from functools import partial
+from itertools import chain
 from typing import Union, cast

 from pytensor.compile.function import function
@@ -47,11 +48,15 @@ def infer_shape(outs, inputs, input_shapes):
         assert len(inp_shp) == inp.type.ndim

     shape_feature = ShapeFeature()
-    shape_feature.on_attach(FunctionGraph([], []))
+    fgraph = FunctionGraph([], [], features=[shape_feature])
+    for v in chain.from_iterable(s for s in input_shapes if s is not None):
+        # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
+        if (node := v.owner) is not None:
+            fgraph.import_node(node, import_missing=True)

     # Initialize shape_of with the input shapes
     for inp, inp_shp in zip(inputs, input_shapes, strict=True):
-        shape_feature.set_shape(inp, inp_shp)
+        shape_feature.set_shape(inp, inp_shp, override=True)

     def local_traverse(out):
         """
...
@@ -36,7 +36,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     core_op_fn, core_op_key = numba_funcify_and_cache_key(
         core_op,
         node=core_node,
-        parent_node=node,
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
...
@@ -274,7 +274,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
         op.scalar_op,
         node=scalar_node,
-        parent_node=node,
         **kwargs,
     )
...
@@ -4,9 +4,10 @@ import pytest
 from pytensor import OpFromGraph, config, function, ifelse
 from pytensor import tensor as pt
 from pytensor.compile import ViewOp
+from pytensor.graph import vectorize_graph
 from pytensor.raise_op import assert_op
 from pytensor.scalar import Add
-from pytensor.tensor import matrix
+from pytensor.tensor import dmatrix, dtensor3, matrix
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py
@@ -171,6 +172,24 @@ def test_ofg_aliased_outputs():
     np.testing.assert_allclose(res, np.ones((2, 2)))
def test_ofg_elemwise_regression():
    """Regression test for https://github.com/pymc-devs/pytensor/issues/1507.

    A Blockwise whose core_op is an OpFromGraph (here wrapping a trivial
    elemwise add) must compile and evaluate correctly in the numba backend.
    """
    # Wrap an elemwise `x + 1` inside an OpFromGraph core op.
    inp = dmatrix("x", shape=(None, None))
    ofg = OpFromGraph(inputs=[inp], outputs=[inp + 1])
    out = ofg(inp)

    # Vectorize over a new leading batch dimension, producing a Blockwise node.
    batched_inp = dtensor3("X_batched", shape=(None, None, None))
    batched_out = vectorize_graph(out, {inp: batched_inp})

    compare_numba_and_py(
        [batched_inp],
        [batched_out],
        [np.random.normal(size=(3, 2, 4))],
        eval_obj_mode=False,
    )
 def test_check_and_raise():
     x = pt.vector()
     x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
...
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment