提交 db734614 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Change behavior of helper set/inc to act on an indexed variable directly

上级 c40d6924
......@@ -824,25 +824,35 @@ class _tensor_py_operators:
"""Return selected slices only."""
return pt.extra_ops.compress(self, a, axis=axis)
def set(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values set to y.
def set(self, y, **kwargs):
"""Return a copy of the variable indexed by self with the indexed values set to y.
Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Raises
------
TypeError:
If self is not the result of a subtensor operation
Examples
--------
>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.set(1, 2)
>>> out = x[1].set(2)
>>> out.eval() # array([1., 2., 1.])
"""
return pt.subtensor.set_subtensor(self[idx], y, **kwargs)
return pt.subtensor.set_subtensor(self, y, **kwargs)
def inc(self, y, **kwargs):
"""Return a copy of the variable indexed by self with the indexed values incremented by y.
def inc(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values incremented by y.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs.
Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
Raises
------
TypeError:
If self is not the result of a subtensor operation
Examples
--------
......@@ -850,10 +860,10 @@ class _tensor_py_operators:
>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.inc(1, 2)
>>> out = x[1].inc(2)
>>> out.eval() # array([1., 3., 1.])
"""
return pt.inc_subtensor(self[idx], y, **kwargs)
return pt.inc_subtensor(self, y, **kwargs)
class TensorVariable(
......
......@@ -438,8 +438,8 @@ class TestTensorInstanceMethods:
idx = [0]
y = 5
assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
assert equal_computations([x.inc(idx, y)], [inc_subtensor(x[idx], y)])
assert equal_computations([x[:, idx].set(y)], [set_subtensor(x[:, idx], y)])
assert equal_computations([x[:, idx].inc(y)], [inc_subtensor(x[:, idx], y)])
def test_set_item_error(self):
x = matrix("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论