提交 5e6b356d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement rewrite to inline Composite constants

上级 d98f33c5
......@@ -33,9 +33,13 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.rewriting.basic import (
broadcast_like,
register_canonicalize,
register_specialize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant
from pytensor.tensor.var import TensorConstant, get_unique_constant_value
class InplaceElemwiseOptimizer(GraphRewriter):
......@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
    """Inline scalar constants in Composite graphs.

    Outer inputs that hold a single unique constant value are removed from the
    Composite's input list and replaced by scalar constants inside its inner
    graph, so fewer inputs need to be passed at runtime.
    """
    scalar_op = node.op.scalar_op
    if not isinstance(scalar_op, aes.Composite):
        return None

    kept_outer_inputs = []
    kept_inner_inputs = []
    constant_replacements = {}
    for outer_inp, inner_inp in zip(node.inputs, scalar_op.fgraph.inputs):
        # Complex variables don't have a `c_literal` that can be inlined
        if "complex" not in outer_inp.type.dtype:
            unique_value = get_unique_constant_value(outer_inp)
            if unique_value is not None:
                constant_replacements[inner_inp] = aes.constant(
                    unique_value, dtype=inner_inp.dtype
                )
                continue
        kept_outer_inputs.append(outer_inp)
        kept_inner_inputs.append(inner_inp)

    # Nothing to inline: leave the node untouched.
    if not constant_replacements:
        return None

    inlined_inner_outs = clone_replace(
        scalar_op.fgraph.outputs, replace=constant_replacements
    )
    new_op = Elemwise(aes.Composite(kept_inner_inputs, inlined_inner_outs))
    new_outputs = new_op.make_node(*kept_outer_inputs).outputs

    # Some of the inlined constants were broadcasting the output shape
    if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
        new_outputs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outputs
        ]

    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB()
compile.optdb.register(
......@@ -1243,6 +1290,13 @@ fuse_seqopt.register(
"fusion",
position=10,
)
# Registered in the fusion sequence database at position 20, i.e. after the
# main "fusion" pass (registered above at position 10), so constants are
# inlined into the Composites that fusion has just created.
fuse_seqopt.register(
    "local_inline_composite_constants",
    in2out(local_inline_composite_constants),
    "fast_run",
    "fusion",
    position=20,
)
def _rebuild_partial_2f1grad_loop(node, wrt):
......
......@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
def test_local_inline_composite_constants(op, np_op, const_shape):
    """After fusion, a broadcastable constant is inlined into the Composite.

    The compiled graph should contain a single Composite Elemwise whose
    inputs are only ``x`` and ``y`` — the constant is absorbed into the
    inner scalar graph — while numerical results stay unchanged.
    """
    const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
    x = vector("x")
    y = vector("y")
    out = at.exp(op(x, const)) + y

    fn = pytensor.function(
        [x, y], out, mode=get_default_mode().including("specialize", "fusion")
    )

    # There should be a single Composite after optimization
    [composite_node] = [
        apply for apply in fn.maker.fgraph.apply_nodes if isinstance(apply.op, Elemwise)
    ]
    assert isinstance(composite_node.op.scalar_op, Composite)
    assert len(composite_node.inputs) == 2  # x and y, but not const

    x_val = np.arange(5).astype(config.floatX)
    y_val = np.ones(5).astype(config.floatX)
    np.testing.assert_allclose(
        fn(x_val, y_val),
        np.exp(np_op(x_val, const)) + y_val,
    )
def test_local_useless_dimshuffle_makevector():
a = scalar()
x = MakeVector(config.floatX)(a)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论