提交 9d5a8256 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add `set` and `add` TensorVariable methods for `set_subtensor` and `inc_subtensor` operations

上级 a6f3f2d3
......@@ -815,6 +815,32 @@ class _tensor_py_operators:
"""Return selected slices only."""
return at.extra_ops.compress(self, a, axis=axis)
def set(self, y, **kwargs):
"""Set values to y, where y is the output of an index operation.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Examples
--------
>>> x = matrix()
>>> out = x[0].set(5)
"""
return at.subtensor.set_subtensor(self, y, **kwargs)
def add(self, y, **kwargs):
"""Add values to y, where y is the output of an index operation.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs
Examples
--------
>>> x = matrix()
>>> out = x[0].add(5)
"""
return at.inc_subtensor(self, y, **kwargs)
class TensorVariable(
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]
......
......@@ -14,7 +14,12 @@ from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
Subtensor,
inc_subtensor,
set_subtensor,
)
from pytensor.tensor.type import (
TensorType,
cscalar,
......@@ -428,6 +433,20 @@ class TestTensorInstanceMethods:
# Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
def test_set_add(self):
x = matrix("x")
idx = [0]
y = 5
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])
msg = "must be the result of a subtensor operation"
with pytest.raises(TypeError, match=msg):
x.set(y)
with pytest.raises(TypeError, match=msg):
x.add(y)
def test_deprecated_import():
with pytest.warns(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论