提交 ae499a49 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not recreate Scalar Ops with custom TransferType for Elemwise inplacing

This helper could arbitrarily override the default output_type from `ScalarOp.make_node` so that the output type matched one of the input types. This can be used to create artificial Op signatures that don't make sense or can't be cleanly implemented in other backends, for instance an Add with signature (int8,int64)->int8. This helper was historically used in: 1. The Elemwise inplace rewrite, I assume as a preventive measure. However, regular use should never require changing the ScalarOp signature: we only try to inplace on inputs that match the output dtype, and recreating the same Op with the same input types should always return the same output type. ScalarOps don't have a concept of inplace; only the Elemwise wrapper does, and it shouldn't require recreating/mutating the original Op. 2. SecondOp. Here it makes sense, but a custom static method works just as well. 3. Inplace tests with the inplace variants of the `@scalar_elemwise` decorator. These test classes were removed; it didn't make sense to test/force Ops to have an artificial signature for the sake of tests.
上级 e0a2a865
......@@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]:
return (type,)
class transfer_type(MetaObject):
    """Output-type preference that forwards selected input types.

    Each entry of ``transfer`` describes one output type:
      * ``None``  -> use ``upcast_out`` of all the input types,
      * a ``str`` -> use that entry verbatim (presumably a dtype name —
        confirm against callers),
      * an ``int`` -> reuse the type of the input at that index.

    Calling the instance with the input types returns one entry per
    element of ``transfer``.
    """

    __props__ = ("transfer",)

    def __init__(self, *transfer):
        # Only input indices, literal entries, or None (meaning "upcast")
        # are valid transfer specifications.
        assert all(isinstance(x, int | str) or x is None for x in transfer)
        self.transfer = transfer

    def __str__(self):
        return f"transfer_type{self.transfer}"

    def __call__(self, *types):
        # Computed once; reused for every `None` entry.
        upcast = upcast_out(*types)
        retval = []
        for i in self.transfer:
            if i is None:
                retval += [upcast]
            elif isinstance(i, str):
                retval += [i]
            else:
                retval += [types[i]]
        return retval
class specific_out(MetaObject):
__props__ = ("spec",)
......@@ -2446,6 +2422,10 @@ clip = Clip(upcast_out_no_complex, name="clip")
class Second(BinaryScalarOp):
    # `second(x, y)` evaluates to `y`; the output type therefore always
    # follows the second input, regardless of the first.
    @staticmethod
    def output_types_preference(_first_type, second_type):
        """Output type is the type of the second input."""
        return [second_type]

    def impl(self, x, y):
        """Return ``y`` unchanged; ``x`` is ignored."""
        return y
......@@ -2474,7 +2454,7 @@ class Second(BinaryScalarOp):
return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
second = Second(transfer_type(1), name="second")
second = Second(name="second")
class Identity(UnaryScalarOp):
......@@ -2515,18 +2495,6 @@ class Cast(UnaryScalarOp):
return convert_to_float32
return self
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a copy of this Op for use by the inplace rewrite.

    This Op's ``__init__`` does not have the same parameters as other
    scalar Ops, which breaks the ``insert_inplace_optimizer``
    optimization.  As a patch, the ``output_types_preference`` passed
    by the optimization is ignored and the current output type is kept.
    This should only be triggered when both input and output have the
    same dtype anyway.
    """
    return self.__class__(self.o_type, name)
def impl(self, input):
return self.ctor(input)
......@@ -4322,22 +4290,6 @@ class Composite(ScalarInnerGraphOp):
return self._name
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a copy of this Op for use by the inplace rewrite.

    This Op's ``__init__`` does not have the same parameters as other
    scalar Ops, which breaks the ``insert_inplace_optimizer``
    optimization.  This method patches around that by rebuilding the Op
    from its own init parameters and installing the requested
    ``output_types_preference``/``name`` on the clone afterwards.
    """
    # Rebuild the Op from this instance's recorded init parameters.
    d = {k: getattr(self, k) for k in self.init_param}
    out = self.__class__(**d)
    if name:
        out.name = name
    else:
        name = out.name
    # Call the grandparent __init__ directly so the preference is set
    # without re-running Composite's own initialization.
    super(Composite, out).__init__(output_types_preference, name)
    return out
@property
def fgraph(self):
if hasattr(self, "_fgraph"):
......
......@@ -136,9 +136,6 @@ class ScalarLoop(ScalarInnerGraphOp):
def fn(self):
raise NotImplementedError
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a clone with the given output-type preference and name (used by the inplace rewrite)."""
    return self.clone(output_types_preference=output_types_preference, name=name)
def make_node(self, n_steps, *inputs):
assert len(inputs) == self.nin - 1
......
......@@ -35,7 +35,6 @@ from pytensor.scalar import (
Mul,
ScalarOp,
get_scalar_type,
transfer_type,
upcast_out,
upgrade_to_float,
)
......@@ -287,22 +286,17 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
op = node.op
scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
]
)
)
else:
new_scalar_op = type(scalar_op)(
transfer_type(
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
)
try:
return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs)
except TypeError:
# Elemwise raises TypeError if we try to inplace an output on an input of a different dtype
if config.optimizer_verbose:
print( # noqa: T201
f"InplaceElemwise failed because the output dtype of {node} changed when rebuilt. "
"Perhaps due to a change in config.floatX or config.cast_policy"
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
# InplaceGraphOptimizer will chug along fine if we return the original node
return node
optdb.register(
......
......@@ -2797,7 +2797,6 @@ class TestARange:
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == "custom":
assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX":
......
......@@ -1200,3 +1200,28 @@ def test_XOR_inplace():
_ = gn(l, r)
# test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l
def test_inplace_dtype_changed():
    """Inplacing must be skipped when the rebuilt output dtype no longer matches any input.

    The graph is built under ``floatX="float32"`` with the
    "numpy+floatX" cast policy, so ``float32 + int32`` yields a float32
    output.  When the function is later compiled under
    ``floatX="float64"``, rebuilding the node changes the output dtype,
    so no input can be safely destroyed.
    """
    with pytensor.config.change_flags(cast_policy="numpy+floatX", floatX="float64"):
        x = pt.vector("x", dtype="float32")
        y = pt.vector("y", dtype="int32")
        with pytensor.config.change_flags(floatX="float32"):
            out = pt.add(x, y)
            assert out.dtype == "float32"

        with pytensor.config.change_flags(floatX="float32"):
            fn32 = pytensor.function(
                [In(x, mutable=True), In(y, mutable=True)],
                out,
                mode="fast_run",
            )
        # Output dtype matches input 0, so the rewrite may destroy it.
        assert fn32.maker.fgraph.outputs[0].owner.op.destroy_map == {0: [0]}

        with pytensor.config.change_flags(floatX="float64"):
            fn64 = pytensor.function(
                [In(x, mutable=True), In(y, mutable=True)],
                out,
                mode="fast_run",
            )
        # Rebuilt output dtype matches no input -> no inplacing allowed.
        assert fn64.maker.fgraph.outputs[0].owner.op.destroy_map == {}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论