提交 cb8b8ac8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Change the semantics of the `set` and `add` helpers

上级 c1bceb93
......@@ -820,35 +820,36 @@ class _tensor_py_operators:
"""Return selected slices only."""
return at.extra_ops.compress(self, a, axis=axis)
def set(self, y, **kwargs):
"""Return a copy of a tensor with the indexed values set to y.
def set(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values set to y.
Self must be the output of an indexing operation.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
Examples
--------
>>> x = matrix()
>>> out = x[0].set(5)
>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.set(1, 2)
>>> out.eval() # array([1., 2., 1.])
"""
return at.subtensor.set_subtensor(self, y, **kwargs)
def add(self, y, **kwargs):
"""Return a copy of a tensor with the indexed values incremented by y.
return at.subtensor.set_subtensor(self[idx], y, **kwargs)
Self must be the output of an indexing operation.
def add(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values incremented by y.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs.
Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
Examples
--------
>>> x = matrix()
>>> out = x[0].add(5)
>>> import pytensor.tensor as pt
>>>
>>> x = pt.ones((3,))
>>> out = x.add(1, 2)
>>> out.eval() # array([1., 3., 1.])
"""
return at.inc_subtensor(self, y, **kwargs)
return at.inc_subtensor(self[idx], y, **kwargs)
class TensorVariable(
......
......@@ -438,14 +438,8 @@ class TestTensorInstanceMethods:
idx = [0]
y = 5
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)])
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)])
msg = "must be the result of a subtensor operation"
with pytest.raises(TypeError, match=msg):
x.set(y)
with pytest.raises(TypeError, match=msg):
x.add(y)
assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
assert equal_computations([x.add(idx, y)], [inc_subtensor(x[idx], y)])
def test_set_item_error(self):
x = matrix("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论