提交 c513419c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Blockwise: Fix OpFromGraph as core_op

上级 01ac0c7c
......@@ -4,6 +4,7 @@ import warnings
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from itertools import chain
from typing import Union, cast
from pytensor.compile.function import function
......@@ -47,11 +48,15 @@ def infer_shape(outs, inputs, input_shapes):
assert len(inp_shp) == inp.type.ndim
shape_feature = ShapeFeature()
shape_feature.on_attach(FunctionGraph([], []))
fgraph = FunctionGraph([], [], features=[shape_feature])
for v in chain.from_iterable(s for s in input_shapes if s is not None):
# Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
if (node := v.owner) is not None:
fgraph.import_node(node, import_missing=True)
# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp)
shape_feature.set_shape(inp, inp_shp, override=True)
def local_traverse(out):
"""
......
......@@ -36,7 +36,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_op_fn, core_op_key = numba_funcify_and_cache_key(
core_op,
node=core_node,
parent_node=node,
**kwargs,
)
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
......
......@@ -274,7 +274,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
op.scalar_op,
node=scalar_node,
parent_node=node,
**kwargs,
)
......
......@@ -4,9 +4,10 @@ import pytest
from pytensor import OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.tensor import matrix
from pytensor.tensor import dmatrix, dtensor3, matrix
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -171,6 +172,24 @@ def test_ofg_aliased_outputs():
np.testing.assert_allclose(res, np.ones((2, 2)))
def test_ofg_elemwise_regression():
# Regression bug for https://github.com/pymc-devs/pytensor/issues/1507
x = dmatrix("x", shape=(None, None))
z = OpFromGraph(
inputs=[x],
outputs=[x + 1],
)(x)
x_batched = dtensor3("X_batched", shape=(None, None, None))
z_batched = vectorize_graph(z, {x: x_batched})
compare_numba_and_py(
[x_batched],
[z_batched],
[np.random.normal(size=(3, 2, 4))],
eval_obj_mode=False,
)
def test_check_and_raise():
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论