提交 9d5a8256 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add `set` and `add` TensorVariable methods for `set_subtensor` and `inc_subtensor` operations

上级 a6f3f2d3
...@@ -815,6 +815,32 @@ class _tensor_py_operators: ...@@ -815,6 +815,32 @@ class _tensor_py_operators:
"""Return selected slices only.""" """Return selected slices only."""
return at.extra_ops.compress(self, a, axis=axis) return at.extra_ops.compress(self, a, axis=axis)
def set(self, y, **kwargs):
"""Set values to y, where y is the output of an index operation.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Examples
--------
>>> x = matrix()
>>> out = x[0].set(5)
"""
return at.subtensor.set_subtensor(self, y, **kwargs)
def add(self, y, **kwargs):
"""Add values to y, where y is the output of an index operation.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs
Examples
--------
>>> x = matrix()
>>> out = x[0].add(5)
"""
return at.inc_subtensor(self, y, **kwargs)
class TensorVariable( class TensorVariable(
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType] _tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]
......
...@@ -14,7 +14,12 @@ from pytensor.tensor.basic import constant ...@@ -14,7 +14,12 @@ from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor from pytensor.tensor.subtensor import (
AdvancedSubtensor,
Subtensor,
inc_subtensor,
set_subtensor,
)
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
cscalar, cscalar,
...@@ -428,6 +433,20 @@ class TestTensorInstanceMethods: ...@@ -428,6 +433,20 @@ class TestTensorInstanceMethods:
# Test equivalent advanced indexing # Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices]) assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
def test_set_add(self):
x = matrix("x")
idx = [0]
y = 5
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])
msg = "must be the result of a subtensor operation"
with pytest.raises(TypeError, match=msg):
x.set(y)
with pytest.raises(TypeError, match=msg):
x.add(y)
def test_deprecated_import(): def test_deprecated_import():
with pytest.warns( with pytest.warns(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论