提交 8cc489b1 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use scalar variables on Numba Elemwise dispatch

上级 8267d0e4
......@@ -30,20 +30,19 @@ from pytensor.scalar.basic import (
OR,
XOR,
Add,
Composite,
IntDiv,
Mul,
ScalarMaximum,
ScalarMinimum,
Sub,
TrueDiv,
get_scalar_type,
scalar_maximum,
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
@singledispatch
......@@ -348,13 +347,8 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
# Creating a new scalar node is more involved and unnecessary
# if the scalar_op is composite, as the fgraph already contains
# all the necessary information.
scalar_node = None
if not isinstance(op.scalar_op, Composite):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)
scalar_op_fn = numba_funcify(
op.scalar_op,
......
......@@ -267,11 +267,11 @@ def compare_numba_and_py(
x, y
)
if isinstance(fgraph, tuple):
fn_inputs, fn_outputs = fgraph
else:
if isinstance(fgraph, FunctionGraph):
fn_inputs = fgraph.inputs
fn_outputs = fgraph.outputs
else:
fn_inputs, fn_outputs = fgraph
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
......
......@@ -15,7 +15,8 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.scalar import float64
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
......@@ -691,3 +692,17 @@ def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)
def test_scalar_loop():
a = float64("a")
scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
x = pt.tensor("x", shape=(3,))
elemwise_loop = Elemwise(scalar_loop)(3, x)
with pytest.warns(UserWarning, match="object mode"):
compare_numba_and_py(
([x], [elemwise_loop]),
(np.array([1, 2, 3], dtype="float64"),),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论