提交 5e6b356d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement rewrite to inline Composite constants

上级 d98f33c5
...@@ -33,9 +33,13 @@ from pytensor.tensor.basic import ( ...@@ -33,9 +33,13 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.rewriting.basic import (
broadcast_like,
register_canonicalize,
register_specialize,
)
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant from pytensor.tensor.var import TensorConstant, get_unique_constant_value
class InplaceElemwiseOptimizer(GraphRewriter): class InplaceElemwiseOptimizer(GraphRewriter):
...@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node): ...@@ -1203,6 +1207,49 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)] return [new_car_op(*elm_inputs)]
@node_rewriter([Elemwise])
def local_inline_composite_constants(fgraph, node):
"""Inline scalar constants in Composite graphs."""
composite_op = node.op.scalar_op
if not isinstance(composite_op, aes.Composite):
return None
new_outer_inputs = []
new_inner_inputs = []
inner_replacements = {}
for outer_inp, inner_inp in zip(node.inputs, composite_op.fgraph.inputs):
# Complex variables don't have a `c_literal` that can be inlined
if "complex" not in outer_inp.type.dtype:
unique_value = get_unique_constant_value(outer_inp)
if unique_value is not None:
inner_replacements[inner_inp] = aes.constant(
unique_value, dtype=inner_inp.dtype
)
continue
new_outer_inputs.append(outer_inp)
new_inner_inputs.append(inner_inp)
if not inner_replacements:
return None
new_inner_outs = clone_replace(
composite_op.fgraph.outputs, replace=inner_replacements
)
new_composite_op = aes.Composite(new_inner_inputs, new_inner_outs)
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
# Some of the inlined constants were broadcasting the output shape
if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
new_outputs = [
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
for new_out in new_outputs
]
copy_stack_trace(node.outputs, new_outputs)
return new_outputs
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB() fuse_seqopt = SequenceDB()
compile.optdb.register( compile.optdb.register(
...@@ -1243,6 +1290,13 @@ fuse_seqopt.register( ...@@ -1243,6 +1290,13 @@ fuse_seqopt.register(
"fusion", "fusion",
position=10, position=10,
) )
fuse_seqopt.register(
"local_inline_composite_constants",
in2out(local_inline_composite_constants),
"fast_run",
"fusion",
position=20,
)
def _rebuild_partial_2f1grad_loop(node, wrt): def _rebuild_partial_2f1grad_loop(node, wrt):
......
...@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs(): ...@@ -1461,6 +1461,32 @@ def test_local_useless_composite_outputs():
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
@pytest.mark.parametrize("const_shape", [(), (1,), (5,), (1, 5), (2, 5)])
@pytest.mark.parametrize("op, np_op", [(at.pow, np.power), (at.add, np.add)])
def test_local_inline_composite_constants(op, np_op, const_shape):
const = np.full(shape=const_shape, fill_value=2.5).astype(config.floatX)
x = vector("x")
y = vector("y")
out = at.exp(op(x, const)) + y
fn = pytensor.function(
[x, y], out, mode=get_default_mode().including("specialize", "fusion")
)
# There should be a single Composite after optimization
[node] = [
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Elemwise)
]
assert isinstance(node.op.scalar_op, Composite)
assert len(node.inputs) == 2 # x and y, but not const
x_test_value = np.arange(5).astype(config.floatX)
y_test_value = np.ones(5).astype(config.floatX)
np.testing.assert_allclose(
fn(x_test_value, y_test_value),
np.exp(np_op(x_test_value, const)) + y_test_value,
)
def test_local_useless_dimshuffle_makevector(): def test_local_useless_dimshuffle_makevector():
a = scalar() a = scalar()
x = MakeVector(config.floatX)(a) x = MakeVector(config.floatX)(a)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论