Replace code snippet with minimal self-contained example for inplace_on_inputs

上级 7d0ee2d8
......@@ -219,11 +219,67 @@ those inputs where it is safe and beneficial to do so.
.. testcode::
def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# Implementation would create a new Op with appropriate destroy_map
# Return self by default if no inplace version is available
return self
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.blockwise import Blockwise
class MyOpWithInplace(Op):
__props__ = ("destroy_a",)
def __init__(self, destroy_a):
self.destroy_a = destroy_a
if destroy_a:
self.destroy_map = {0: [0]}
def make_node(self, a):
return Apply(self, [a], [a.type()])
def perform(self, node, inputs, output_storage):
[a] = inputs
if not self.destroy_a:
a = a.copy()
a[0] += 1
output_storage[0][0] = a
def inplace_on_inputs(self, allowed_inplace_inputs):
if 0 in allowed_inplace_inputs:
return MyOpWithInplace(destroy_a=True)
return self
a = pt.vector("a")
# Only Blockwise trigger inplace automatically for now
# Since the Blockwise isn't needed in this case, it will be removed after the inplace optimization
op = Blockwise(MyOpWithInplace(destroy_a=False), signature="(a)->(a)")
out = op(a)
# Give PyTensor permission to inplace on user provided inputs
fn = pytensor.function([pytensor.In(a, mutable=True)], out)
# Confirm that we have the inplace version of the Op
fn.dprint(print_destroy_map=True)
.. testoutput::
Blockwise{MyOpWithInplace{destroy_a=True}, (a)->(a)} [id A] '' 5
└─ a [id B]
The output shows that the function now uses the inplace version (`destroy_a=True`).
.. testcode::
# Test that inplace modification works
test_a = np.zeros(5)
result = fn(test_a)
print("Function result:", result)
print("Original array after function call:", test_a)
.. testoutput::
Function result: [1. 0. 0. 0. 0.]
Original array after function call: [1. 0. 0. 0. 0.]
Currently, this method is primarily used with Blockwise operations through PyTensor's
rewriting system, but it will be extended to support core ops directly in future versions.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论