提交 ae499a49 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not recreate Scalar Ops with custom TransferType for Elemwise inplacing

This helper could arbitrarily override the default output_type from `ScalarOp.make_node` so that the output type matched one of the input types. This could be used to create artificial Op signatures that don't make sense or can't be cleanly implemented in other backends — for instance an Add with signature (int8, int64) -> int8. This helper was historically used in: 1. The Elemwise inplace rewrite, presumably as a preventive measure. However, regular use should never require changing the ScalarOp signature: we only try to inplace on inputs that match the output dtype, and recreating the same Op with the same input types should always return the same output type. ScalarOps have no concept of inplace — only the Elemwise wrapper does — and it shouldn't require recreating or mutating the original Op. 2. SecondOp. Here it makes sense, but a custom static method works just as well. 3. Inplace tests with the inplace variants of the `@scalar_elemwise` decorator. These test classes were removed anyway; it didn't make sense to force Ops to have an artificial signature just for the sake of tests.
上级 e0a2a865
...@@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]: ...@@ -1101,30 +1101,6 @@ def same_out_float_only(type) -> tuple[ScalarType]:
return (type,) return (type,)
class transfer_type(MetaObject):
    """Output-type preference that copies selected input types.

    Each entry of ``transfer`` selects the type of one output:
    an ``int`` index copies the type of that input, a ``str`` is used
    verbatim, and ``None`` falls back to the upcast of all input types.
    """

    __props__ = ("transfer",)

    def __init__(self, *transfer):
        assert all(spec is None or isinstance(spec, int | str) for spec in transfer)
        self.transfer = transfer

    def __str__(self):
        return f"transfer_type{self.transfer}"

    def __call__(self, *types):
        # The common upcast is only needed for ``None`` entries, but it is
        # cheap enough to compute unconditionally.
        common = upcast_out(*types)
        out_types = []
        for spec in self.transfer:
            if spec is None:
                out_types.append(common)
            elif isinstance(spec, str):
                out_types.append(spec)
            else:
                out_types.append(types[spec])
        return out_types
class specific_out(MetaObject): class specific_out(MetaObject):
__props__ = ("spec",) __props__ = ("spec",)
...@@ -2446,6 +2422,10 @@ clip = Clip(upcast_out_no_complex, name="clip") ...@@ -2446,6 +2422,10 @@ clip = Clip(upcast_out_no_complex, name="clip")
class Second(BinaryScalarOp): class Second(BinaryScalarOp):
@staticmethod
def output_types_preference(_first_type, second_type):
return [second_type]
def impl(self, x, y): def impl(self, x, y):
return y return y
...@@ -2474,7 +2454,7 @@ class Second(BinaryScalarOp): ...@@ -2474,7 +2454,7 @@ class Second(BinaryScalarOp):
return DisconnectedType()(), y.zeros_like(dtype=config.floatX) return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
second = Second(transfer_type(1), name="second") second = Second(name="second")
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
...@@ -2515,18 +2495,6 @@ class Cast(UnaryScalarOp): ...@@ -2515,18 +2495,6 @@ class Cast(UnaryScalarOp):
return convert_to_float32 return convert_to_float32
return self return self
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a copy of this op for use by the inplace rewrite.

    ``Cast.__init__`` does not take the same parameters as other scalar
    ops, which breaks ``insert_inplace_optimizer``.  As a workaround, the
    ``output_types_preference`` passed in by the rewrite is ignored and the
    op's own output type (``self.o_type``) is reused instead.  This should
    only be triggered when the input and output have the same dtype anyway.

    Parameters
    ----------
    output_types_preference
        Ignored; kept for signature compatibility with other scalar ops.
    name
        Optional name for the new op.
    """
    return type(self)(self.o_type, name)
def impl(self, input): def impl(self, input):
return self.ctor(input) return self.ctor(input)
...@@ -4322,22 +4290,6 @@ class Composite(ScalarInnerGraphOp): ...@@ -4322,22 +4290,6 @@ class Composite(ScalarInnerGraphOp):
return self._name return self._name
def make_new_inplace(self, output_types_preference=None, name=None):
    """Return a fresh copy of this op, reinitialized for the inplace rewrite.

    ``Composite.__init__`` does not take the same parameters as other
    scalar ops, which breaks ``insert_inplace_optimizer``.  This override
    works around that: it rebuilds the op from its ``init_param``
    attributes and then applies the requested ``output_types_preference``
    and ``name`` via the parent-class initializer.
    """
    d = {k: getattr(self, k) for k in self.init_param}
    out = self.__class__(**d)
    if name:
        out.name = name
    else:
        name = out.name
    # Re-run the parent-class initializer so the new preference/name take
    # effect on the rebuilt op.
    super(Composite, out).__init__(output_types_preference, name)
    return out
@property @property
def fgraph(self): def fgraph(self):
if hasattr(self, "_fgraph"): if hasattr(self, "_fgraph"):
......
...@@ -136,9 +136,6 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -136,9 +136,6 @@ class ScalarLoop(ScalarInnerGraphOp):
def fn(self): def fn(self):
raise NotImplementedError raise NotImplementedError
def make_new_inplace(self, output_types_preference=None, name=None):
    """Clone this op, carrying over the given output-type preference and name."""
    new_props = {"output_types_preference": output_types_preference, "name": name}
    return self.clone(**new_props)
def make_node(self, n_steps, *inputs): def make_node(self, n_steps, *inputs):
assert len(inputs) == self.nin - 1 assert len(inputs) == self.nin - 1
......
...@@ -35,7 +35,6 @@ from pytensor.scalar import ( ...@@ -35,7 +35,6 @@ from pytensor.scalar import (
Mul, Mul,
ScalarOp, ScalarOp,
get_scalar_type, get_scalar_type,
transfer_type,
upcast_out, upcast_out,
upgrade_to_float, upgrade_to_float,
) )
...@@ -287,22 +286,17 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer): ...@@ -287,22 +286,17 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
op = node.op op = node.op
scalar_op = op.scalar_op scalar_op = op.scalar_op
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()} inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"): try:
new_scalar_op = scalar_op.make_new_inplace( return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs)
transfer_type( except TypeError:
*[ # Elemwise raises TypeError if we try to inplace an output on an input of a different dtype
inplace_pattern.get(i, o.dtype) if config.optimizer_verbose:
for i, o in enumerate(node.outputs) print( # noqa: T201
] f"InplaceElemwise failed because the output dtype of {node} changed when rebuilt. "
) "Perhaps due to a change in config.floatX or config.cast_policy"
)
else:
new_scalar_op = type(scalar_op)(
transfer_type(
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
)
) )
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs) # InplaceGraphOptimizer will chug along fine if we return the original node
return node
optdb.register( optdb.register(
......
...@@ -2797,7 +2797,6 @@ class TestARange: ...@@ -2797,7 +2797,6 @@ class TestARange:
out = arange(start, stop, 1) out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=mode) f = function([start, stop], out.shape, mode=mode)
assert len(f.maker.fgraph.toposort()) == 5 assert len(f.maker.fgraph.toposort()) == 5
# 4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
if config.cast_policy == "custom": if config.cast_policy == "custom":
assert out.dtype == "int64" assert out.dtype == "int64"
elif config.cast_policy == "numpy+floatX": elif config.cast_policy == "numpy+floatX":
......
...@@ -1200,3 +1200,28 @@ def test_XOR_inplace(): ...@@ -1200,3 +1200,28 @@ def test_XOR_inplace():
_ = gn(l, r) _ = gn(l, r)
# test the in-place stuff # test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l assert np.all(l == np.asarray([0, 1, 1, 0])), l
def test_inplace_dtype_changed():
    # Regression test: the inplace rewrite must not destroy an input whose
    # dtype no longer matches the output dtype after the graph is rebuilt
    # under a different ``floatX``/``cast_policy`` configuration.
    # NOTE(review): indentation/nesting reconstructed from a diff view —
    # confirm the `with` block structure against the upstream commit.
    with pytensor.config.change_flags(cast_policy="numpy+floatX", floatX="float64"):
        x = pt.vector("x", dtype="float32")
        y = pt.vector("y", dtype="int32")
        with pytensor.config.change_flags(floatX="float32"):
            out = pt.add(x, y)
            # Under numpy+floatX with floatX=float32 the float32/int32 add is
            # kept at float32 instead of being promoted.
            assert out.dtype == "float32"
        with pytensor.config.change_flags(floatX="float32"):
            fn32 = pytensor.function(
                [In(x, mutable=True), In(y, mutable=True)],
                out,
                mode="fast_run",
            )
            # Output dtype matches input 0 (x), so inplacing on it is allowed.
            assert fn32.maker.fgraph.outputs[0].owner.op.destroy_map == {0: [0]}
        with pytensor.config.change_flags(floatX="float64"):
            fn64 = pytensor.function(
                [In(x, mutable=True), In(y, mutable=True)],
                out,
                mode="fast_run",
            )
            # Rebuilt under float64 the output dtype changes, so no input can
            # be safely overwritten and the inplace rewrite must bail out.
            assert fn64.maker.fgraph.outputs[0].owner.op.destroy_map == {}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论