Unverified 提交 b79d2329 authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Implement `ufunc_outer` like `add.outer` for binary `Elemwise` operations

上级 35ae5db4
...@@ -1170,6 +1170,16 @@ class Elemwise(OpenMPOp): ...@@ -1170,6 +1170,16 @@ class Elemwise(OpenMPOp):
else: else:
return () return ()
def outer(self, x, y):
from pytensor.tensor.basic import expand_dims
if self.scalar_op.nin not in (-1, 2):
raise NotImplementedError("outer is only available for binary operators")
x_ = expand_dims(x, tuple(range(-y.ndim, 0)))
y_ = expand_dims(y, tuple(range(x.ndim)))
return self(x_, y_)
class CAReduce(COp): class CAReduce(COp):
"""Reduces a scalar operation along specified axes. """Reduces a scalar operation along specified axes.
......
...@@ -8,7 +8,9 @@ import pytest ...@@ -8,7 +8,9 @@ import pytest
import pytensor import pytensor
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.tensor as pt
import tests.unittest_tools as utt import tests.unittest_tools as utt
from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
...@@ -893,6 +895,25 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -893,6 +895,25 @@ class TestElemwise(unittest_tools.InferShapeTester):
): ):
x + y x + y
@pytest.mark.parametrize(
"shape_x, shape_y, op, np_op",
[
((3, 5), (7, 1, 3), pt.add, np.add),
((2, 3), (1, 4), pt.mul, np.multiply),
],
)
def test_outer(self, shape_x, shape_y, op, np_op):
x = tensor(dtype=np.float64, shape=shape_x)
y = tensor(dtype=np.float64, shape=shape_y)
z = op.outer(x, y)
f = function([x, y], z)
x1 = np.ones(shape_x)
y1 = np.ones(shape_y)
np.testing.assert_array_equal(f(x1, y1), np_op.outer(x1, y1))
def test_not_implemented_elemwise_grad(): def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op. # Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论