提交 cb8b8ac8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Change the semantics of the `set` and `add` helpers

上级 c1bceb93
...@@ -820,35 +820,36 @@ class _tensor_py_operators: ...@@ -820,35 +820,36 @@ class _tensor_py_operators:
"""Return selected slices only.""" """Return selected slices only."""
return at.extra_ops.compress(self, a, axis=axis) return at.extra_ops.compress(self, a, axis=axis)
def set(self, y, **kwargs): def set(self, idx, y, **kwargs):
"""Return a copy of a tensor with the indexed values set to y. """Return a copy of self with the indexed values set to y.
Self must be the output of an indexing operation. Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
Examples Examples
-------- --------
>>> import pytensor.tensor as pt
>>> x = matrix() >>>
>>> out = x[0].set(5) >>> x = pt.ones((3,))
>>> out = x.set(1, 2)
>>> out.eval() # array([1., 2., 1.])
""" """
return at.subtensor.set_subtensor(self, y, **kwargs) return at.subtensor.set_subtensor(self[idx], y, **kwargs)
def add(self, y, **kwargs):
"""Return a copy of a tensor with the indexed values incremented by y.
Self must be the output of an indexing operation. def add(self, idx, y, **kwargs):
"""Return a copy of self with the indexed values incremented by y.
Equivalent to inc_subtensor(self, y). See docstrings for kwargs. Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
Examples Examples
-------- --------
>>> x = matrix() >>> import pytensor.tensor as pt
>>> out = x[0].add(5) >>>
>>> x = pt.ones((3,))
>>> out = x.add(1, 2)
>>> out.eval() # array([1., 3., 1.])
""" """
return at.inc_subtensor(self, y, **kwargs) return at.inc_subtensor(self[idx], y, **kwargs)
class TensorVariable( class TensorVariable(
......
...@@ -438,14 +438,8 @@ class TestTensorInstanceMethods: ...@@ -438,14 +438,8 @@ class TestTensorInstanceMethods:
idx = [0] idx = [0]
y = 5 y = 5
assert equal_computations([x[idx].set(y)], [set_subtensor(x[idx], y)]) assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
assert equal_computations([x[idx].add(y)], [inc_subtensor(x[idx], y)]) assert equal_computations([x.add(idx, y)], [inc_subtensor(x[idx], y)])
msg = "must be the result of a subtensor operation"
with pytest.raises(TypeError, match=msg):
x.set(y)
with pytest.raises(TypeError, match=msg):
x.add(y)
def test_set_item_error(self): def test_set_item_error(self):
x = matrix("x") x = matrix("x")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论