提交 d5a054d1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize lift of Subtensor over Elemwise

Split off Subtensor of Unbroadcast into its own rewrite
上级 f1db1bd6
...@@ -108,73 +108,79 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -108,73 +108,79 @@ def local_subtensor_of_dot(fgraph, node):
return [r] return [r]
# fast_compile to allow opt subtensor(cast{float32}(make_vector)) @register_canonicalize("shape_unsafe")
@register_canonicalize("fast_compile") @register_specialize("shape_unsafe")
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
def local_subtensor_lift(fgraph, node): def local_subtensor_of_elemwise(fgraph, node):
"""Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
exp(x)[:, 0] -> exp(x[:, 0])
add(x, y)[0] -> add(x[0], y[0])
add(x[None], y)[2] -> add(x, y[2])
""" """
unary(x)[idx] -> unary(x[idx])#any broadcast pattern. elem, *idx = node.inputs
Handles the following unary ops: if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
elemwise(x,...)[idx] -> elemwise(x[idx],...) return None
when x,... are broadcasted scalar or not broadcasted at all
""" if len(fgraph.clients[elem]) > 1:
if isinstance(node.op, Subtensor): # Elemwise output is used beyond the Subtensor.
u = node.inputs[0] # Get out to avoid repeated computations
if u.owner is None or len(fgraph.clients[u]) > 1: return None
return False
if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
idx = node.inputs[1:]
x_idx = node.op(u.owner.inputs[0], *idx) elem_inputs = elem.owner.inputs
# Copy over previous output stacktrace elem_bcast = elem.type.broadcastable
copy_stack_trace(node.outputs, x_idx) if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
ret = u.owner.op(x_idx) # No need to worry about implicit broadcasting.
# Copy over previous output stacktrace indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
if isinstance(u.owner.op, Elemwise):
new_inputs = []
if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
# There is no broadcastable in the inputs
idx = node.inputs[1:]
new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], new_inputs)
ret = u.owner.op(*new_inputs)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
# There is no broadcastable in the inputs or it is scalar
idx = node.inputs[1:]
new_inputs = []
for i in u.owner.inputs:
if sum(i.type.broadcastable) == 0:
new_inputs.append(node.op(i, *idx))
else:
# If the subtensor remove some dims, we must
# lower the number of dimensions of this scalar.
if node.outputs[0].ndim == i.ndim:
new_inputs.append(i)
else: else:
new_inputs.append( # The original indices may not make sense on some of the broadcasted dimensions
i.dimshuffle(["x"] * node.outputs[0].ndim) new_idxs = [list(idx_tuple) for _ in elem_inputs]
for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
zip(
idx_tuple,
elem_bcast,
*(inp.type.broadcastable for inp in elem_inputs),
# Indices can be shorter than input ndims
strict=False,
) )
):
if is_full_slice(dim_idx):
# Full slice can be safely applied to all inputs
continue
# Copy over previous output stacktrace if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
copy_stack_trace(node.outputs[0], new_inputs) # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
continue
ret = u.owner.op(*new_inputs) # Some dims are broadcasted, so we need to adapt their indices
# Copy over previous output stacktrace # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
# and stacktrace from previous unary operation # Integer indexing drops the dimension, so we index by zero for the broadcsated inputs
copy_stack_trace([node.outputs[0], node.inputs[0]], ret) safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
return [ret] for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
if dim_bcast_inp:
inp_idx[dim] = safe_bcast_dim_idx
indexed_inputs = [
inp[tuple(new_idx)]
for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
]
[old_out] = node.outputs
# Copy stack trace to new inputs
[copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
# Define elemwise operation on indexed inputs
new_out = elem.owner.op(*indexed_inputs)
# Copy stack trace to new output
copy_stack_trace([old_out, *node.inputs], new_out)
return [new_out]
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
......
import numpy as np import numpy as np
import pytest import pytest
import unittest_tools as utt
from pytensor import ( from pytensor import (
Mode, Mode,
...@@ -25,13 +24,11 @@ from pytensor.printing import debugprint ...@@ -25,13 +24,11 @@ from pytensor.printing import debugprint
from pytensor.tensor import ( from pytensor.tensor import (
add, add,
exp, exp,
inplace,
iscalar, iscalar,
iscalars, iscalars,
lscalar, lscalar,
lscalars, lscalars,
matrix, matrix,
scalar,
shape, shape,
slicetype, slicetype,
specify_shape, specify_shape,
...@@ -43,6 +40,7 @@ from pytensor.tensor.basic import MakeVector, expand_dims, make_vector ...@@ -43,6 +40,7 @@ from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.subtensor_lift import ( from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector, local_subtensor_make_vector,
local_subtensor_of_elemwise,
local_subtensor_shape_constant, local_subtensor_shape_constant,
) )
from pytensor.tensor.shape import SpecifyShape, _shape from pytensor.tensor.shape import SpecifyShape, _shape
...@@ -58,22 +56,8 @@ mode_opt = get_mode(mode_opt) ...@@ -58,22 +56,8 @@ mode_opt = get_mode(mode_opt)
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None) NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class TestLocalSubtensorLift: class TestLocalSubtensorOfElemwise:
def test_basic(self): def test_unary_multiple_clients(self):
# basic test that the Op works
x = matrix("x")
f = function([x], exp(x)[0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) # first subtensor
assert prog[1].op == exp
assert len(prog) == 2
f([[0, 1], [2, 3]]) # let debugmode test something
def test_basic_1(self):
# as test0, but we reuse the output of the elemwise # as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor # So we should not lift the subtensor
x = matrix("x") x = matrix("x")
...@@ -87,85 +71,16 @@ class TestLocalSubtensorLift: ...@@ -87,85 +71,16 @@ class TestLocalSubtensorLift:
assert isinstance(prog[1].op, Subtensor) # first subtensor assert isinstance(prog[1].op, Subtensor) # first subtensor
assert isinstance(prog[2].op, DeepCopyOp) assert isinstance(prog[2].op, DeepCopyOp)
assert len(prog) == 3 assert len(prog) == 3
f([[0, 1], [2, 3]]) # let debugmode test something
def test_basic_2(self):
# basic test that the optimization work with scalar broadcasted
x = matrix("x")
y = scalar("y")
z = matrix("z")
f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert isinstance(prog[2].op, Subtensor)
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
assert len(prog) == 4
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=[Subtensor])
# let debugmode test something
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
def test_basic_3(self):
# as 1, but take a slice
x = matrix("x")
y = scalar("y")
z = matrix("z")
f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert isinstance(prog[2].op, Subtensor)
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
assert len(prog) == 4
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=[Subtensor])
# let debugmode test something
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
def test_basic_4(self):
# basic test that the optimization does work with broadcasting
# for unary elemwise.
y = vector("y")
f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert prog[2].op == exp
assert len(prog) == 3
f([4, 5]) # let debugmode test something
@utt.assertFailure_fast
def test_basic_5(self):
# basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to,
# ... but right now it doesn't, so it shouldn't try.
x = matrix("x")
y = vector("y")
f = function([x, y], exp(x + y)[0], mode=mode_opt)
# Opt doesn't apply, so no need for check_stack_trace x_test = [[0, 1], [2, 3]]
# assert check_stack_trace(f, ops_to_check='all') res1, res2 = f(x_test)
np.testing.assert_allclose(
prog = f.maker.fgraph.toposort() res1,
assert isinstance(prog[0].op, DimShuffle) np.exp(x_test)[0],
assert prog[1].op == add )
assert isinstance(prog[2].op, Subtensor) # first subtensor np.testing.assert_allclose(res2, np.exp(x_test))
assert prog[3].op == inplace.exp_inplace
assert len(prog) == 4
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_6(self): def test_multinary_multiple_clients(self):
# test that we don't lift when we reuse the output of the # test that we don't lift when we reuse the output of the
# elemwise for other computation. # elemwise for other computation.
x = matrix("x") x = matrix("x")
...@@ -181,26 +96,84 @@ class TestLocalSubtensorLift: ...@@ -181,26 +96,84 @@ class TestLocalSubtensorLift:
# first subtensor # first subtensor
assert isinstance(prog[2].op, Subtensor) assert isinstance(prog[2].op, Subtensor)
assert len(prog) == 3 assert len(prog) == 3
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_7(self): x_test = np.array([[0, 1], [2, 3]]).astype(x.dtype)
# basic test that the optimization works with a scalar as input, y_test = np.array([4, 5]).astype(y.dtype)
# and a scalar as output (no broadcasting of the scalar needed). res1, res2 = f(x_test, y_test)
# The optimization used to fail and display an ERROR message. np.testing.assert_allclose(
res1,
np.exp(x_test + y_test)[0],
)
np.testing.assert_allclose(
res2,
np.exp(x_test + y_test) + x_test,
)
@pytest.mark.parametrize(
"original_fn, expected_fn",
[
# Unary integer indexing
(lambda x, y: exp(x)[0], lambda x, y: exp(x[0])),
# Unary integer with expand_dims
(lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])),
# Integer indexing on non-broadcastable dimension
(lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])),
# Slice indexing on non-broadcastable dimension
(lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])),
# Integer indexing on broacastable dimension
(lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)),
(lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])),
(
lambda x, y: add(x[None, :], y[:, None])[2],
lambda x, y: add(x, y[2][None]),
),
(
lambda x, y: add(x[:, None], y[None, :])[:, 2],
lambda x, y: add(x, y[2][None]),
),
# Slice indexing on broadcastable dimension
(
lambda x, y: add(x[None], y[None])[1:],
lambda x, y: add(x[None][1:], y[None][1:]),
),
(
lambda x, y: add(x[None, :], y[:, None])[1:],
lambda x, y: add(x[None, :], y[1:][:, None]),
),
],
)
def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
rng = np.random.default_rng(257)
x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))
x_test = rng.normal(size=x.type.shape).astype(x.dtype)
y_test = rng.normal(size=y.type.shape).astype(y.dtype)
out = original_fn(x, y)
expected_opt_out = expected_fn(x, y)
opt_out = rewrite_graph(out)
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
[expected_opt_out, opt_out], print_type=True
)
eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
np.testing.assert_allclose(
opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
out.eval({x: x_test, y: y_test}, **eval_kwargs),
)
x = vector("x") def test_local_subtensor_of_elemwise_multiple_clients(self):
y = scalar("y") x = pt.matrix("x", shape=(5, 3))
f = function([x, y], exp(x + y)[0], mode=mode_opt) y = pt.matrix("y", shape=(5, 3))
out1 = add(x, y)
out2 = out1[0]
# Check stacktrace was copied over correctly after opt was applied # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
assert check_stack_trace(f, ops_to_check=Subtensor) fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
prog = f.maker.fgraph.toposort() # Otherwise it should work
assert isinstance(prog[0].op, Subtensor) fgraph.remove_output(0)
# Composite{add,exp} assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
assert isinstance(prog[1].op.scalar_op, ps.Composite)
assert len(prog) == 2
f([1, 2, 3], 4) # let debugmode test something
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论