Commit 8cc489b1 authored by ricardoV94, committed by Ricardo Vieira

Use scalar variables on Numba Elemwise dispatch

Parent: 8267d0e4
@@ -30,20 +30,19 @@ from pytensor.scalar.basic import (
     OR,
     XOR,
     Add,
-    Composite,
     IntDiv,
     Mul,
     ScalarMaximum,
     ScalarMinimum,
     Sub,
     TrueDiv,
+    get_scalar_type,
     scalar_maximum,
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.type import scalar
 
 
 @singledispatch
@@ -348,13 +347,8 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
 
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
-    # Creating a new scalar node is more involved and unnecessary
-    # if the scalar_op is composite, as the fgraph already contains
-    # all the necessary information.
-    scalar_node = None
-    if not isinstance(op.scalar_op, Composite):
-        scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
-        scalar_node = op.scalar_op.make_node(*scalar_inputs)
+    scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
+    scalar_node = op.scalar_op.make_node(*scalar_inputs)
 
     scalar_op_fn = numba_funcify(
         op.scalar_op,
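The hunk above is the core of the commit: the dispatcher now always rebuilds a scalar Apply node from scalar variables, instead of skipping that step when the scalar op is a Composite. A minimal sketch of the new logic, where `op` and `node` stand for the Elemwise Op and its Apply node and the helper name is hypothetical (a simplification, not the full dispatch function):

from pytensor.scalar.basic import get_scalar_type

def _rebuild_scalar_node(op, node):
    # One scalar variable per tensor input, matching each input's dtype
    scalar_inputs = [get_scalar_type(dtype=inp.dtype)() for inp in node.inputs]
    # Re-apply the scalar op on those variables; with this commit this happens
    # unconditionally, whereas the old code skipped Composite scalar ops
    return op.scalar_op.make_node(*scalar_inputs)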
@@ -267,11 +267,11 @@ def compare_numba_and_py(
         x, y
     )
 
-    if isinstance(fgraph, tuple):
-        fn_inputs, fn_outputs = fgraph
-    else:
+    if isinstance(fgraph, FunctionGraph):
         fn_inputs = fgraph.inputs
         fn_outputs = fgraph.outputs
+    else:
+        fn_inputs, fn_outputs = fgraph
 
     fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
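With this change the test helper checks for a FunctionGraph explicitly and treats any other value as an (inputs, outputs) pair. Both call forms, sketched below with placeholder variable names and the helper's other keyword arguments omitted:

compare_numba_and_py(fgraph, input_values)        # an explicit FunctionGraph
compare_numba_and_py(([x], [out]), (x_value,))    # an (inputs, outputs) tuple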
@@ -15,7 +15,8 @@ from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.gradient import grad
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
-from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.scalar import float64
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
@@ -691,3 +692,17 @@ def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
     return careduce_benchmark_tester(
         axis, c_contiguous, mode="NUMBA", benchmark=benchmark
     )
+
+
+def test_scalar_loop():
+    a = float64("a")
+    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+
+    x = pt.tensor("x", shape=(3,))
+    elemwise_loop = Elemwise(scalar_loop)(3, x)
+
+    with pytest.warns(UserWarning, match="object mode"):
+        compare_numba_and_py(
+            ([x], [elemwise_loop]),
+            (np.array([1, 2, 3], dtype="float64"),),
+        )
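The new test wraps a ScalarLoop in Elemwise and runs it through compare_numba_and_py; the pytest.warns assertion documents that the Numba backend currently falls back to object mode for this Op. A rough sanity check of the expected values, assuming the loop's first runtime argument is the step count and the carry doubles on each step (an assumption for illustration, not stated in the diff):

import numpy as np

x_val = np.array([1.0, 2.0, 3.0])
n_steps = 3
expected = x_val * 2**n_steps  # doubling the carry three times -> [8., 16., 24.]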