提交 671cb44b 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix local_careduce_fusion rewrite

上级 6b189ee3
...@@ -1150,11 +1150,20 @@ def local_careduce_fusion(fgraph, node): ...@@ -1150,11 +1150,20 @@ def local_careduce_fusion(fgraph, node):
"""Fuse a `CAReduce` applied to an `Elemwise`.""" """Fuse a `CAReduce` applied to an `Elemwise`."""
(car_input,) = node.inputs (car_input,) = node.inputs
car_scalar_op = node.op.scalar_op
# FIXME: This check is needed because of the faulty logic in the FIXME below!
# Right now, rewrite only works for `Sum`/`Prod`
if not isinstance(car_scalar_op, (aes.Add, aes.Mul)):
return None
elm_node = car_input.owner elm_node = car_input.owner
if elm_node is None or not isinstance(elm_node.op, Elemwise): if elm_node is None or not isinstance(elm_node.op, Elemwise):
return False return False
elm_scalar_op = elm_node.op.scalar_op
elm_inputs = elm_node.inputs elm_inputs = elm_node.inputs
elm_outputs = elm_node.outputs elm_outputs = elm_node.outputs
...@@ -1166,21 +1175,15 @@ def local_careduce_fusion(fgraph, node): ...@@ -1166,21 +1175,15 @@ def local_careduce_fusion(fgraph, node):
return False return False
# Don't form the fusion when the target language is Python # Don't form the fusion when the target language is Python
elm_scalar_op = elm_node.op.scalar_op
car_scalar_op = node.op.scalar_op
if get_target_language() == ("py",): if get_target_language() == ("py",):
return False return False
try: if not elm_scalar_op.supports_c_code(elm_inputs, elm_outputs):
elm_scalar_op.c_code( return None
elm_node,
"test_presence_of_c_code",
["x" for x in elm_inputs],
["z" for z in elm_outputs],
{"fail": "%(fail)s"},
)
# FIXME: This fails with Ops like `Max` whose `c_code` always expects two inputs!
# Should implement a `CAReduce.supports_c_code`?
try:
car_scalar_op.c_code( car_scalar_op.c_code(
node, node,
"test_presence_of_c_code", "test_presence_of_c_code",
...@@ -1191,18 +1194,24 @@ def local_careduce_fusion(fgraph, node): ...@@ -1191,18 +1194,24 @@ def local_careduce_fusion(fgraph, node):
except (NotImplementedError, MethodNotDefined): except (NotImplementedError, MethodNotDefined):
return False return False
car_axis = node.op.axis car_op = node.op
car_acc_dtype = node.op.acc_dtype
scalar_elm_inputs = [ scalar_elm_inputs = [
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
] ]
elm_output = elm_scalar_op(*scalar_elm_inputs) elm_output = elm_scalar_op(*scalar_elm_inputs)
# This input represents the previous value in the `CAReduce` binary reduction # This input represents the previous value in the `CAReduce` binary reduction
carried_car_input = elm_output.type() carried_car_input = aes.get_scalar_type(car_acc_dtype).make_variable()
scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]
scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
if scalar_fused_output.type.dtype != car_acc_dtype:
scalar_fused_output = aes.cast(scalar_fused_output, car_acc_dtype)
fused_scalar_op = aes.Composite( fused_scalar_op = aes.Composite(
inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs inputs=[carried_car_input] + scalar_elm_inputs, outputs=[scalar_fused_output]
) )
# The fused `Op` needs to look and behave like a `BinaryScalarOp` # The fused `Op` needs to look and behave like a `BinaryScalarOp`
...@@ -1211,7 +1220,13 @@ def local_careduce_fusion(fgraph, node): ...@@ -1211,7 +1220,13 @@ def local_careduce_fusion(fgraph, node):
fused_scalar_op.nin = 2 fused_scalar_op.nin = 2
fused_scalar_op.nout = 1 fused_scalar_op.nout = 1
new_car_op = CAReduce(fused_scalar_op, car_axis) new_car_op = CAReduce(
scalar_op=fused_scalar_op,
axis=car_op.axis,
acc_dtype=car_acc_dtype,
dtype=car_op.dtype,
upcast_discrete_output=car_op.upcast_discrete_output,
)
return [new_car_op(*elm_inputs)] return [new_car_op(*elm_inputs)]
......
...@@ -1177,8 +1177,24 @@ class TestFusion: ...@@ -1177,8 +1177,24 @@ class TestFusion:
) )
@pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("inp_dtype", ("floatX", "int32"))
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_single_input(self, linker, axis): @pytest.mark.parametrize(
"careduce_op, numpy_op",
[
(at_sum, np.sum),
pytest.param(
at_all,
np.all,
marks=pytest.mark.xfail(
reason="Rewrite logic does not support all CAReduce"
),
),
],
)
def test_CAReduce_single_input(
self, linker, inp_dtype, axis, careduce_op, numpy_op
):
"""Make sure that `CAReduce` and `Elemwise` fusions work with a single input.""" """Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""
mode = Mode(linker=linker) mode = Mode(linker=linker)
...@@ -1188,8 +1204,8 @@ class TestFusion: ...@@ -1188,8 +1204,8 @@ class TestFusion:
"inplace", "inplace",
) )
x = tensor(dtype="floatX", shape=(None, None, None), name="x") x = tensor(dtype=inp_dtype, shape=(None, None, None), name="x")
out = exp(x).sum(axis=axis) out = careduce_op(exp(x), axis=axis)
out_fn = function([x], out, mode=mode) out_fn = function([x], out, mode=mode)
...@@ -1198,9 +1214,9 @@ class TestFusion: ...@@ -1198,9 +1214,9 @@ class TestFusion:
assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite) assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)
rng = np.random.default_rng(2320) rng = np.random.default_rng(2320)
x_val = rng.random((4, 3, 2), dtype=config.floatX) x_val = rng.random((4, 3, 2)).astype(x.type.dtype)
exp_res = np.exp(x_val).sum(axis=axis) exp_res = numpy_op(np.exp(x_val), axis=axis)
out_val = out_fn(x_val) out_val = out_fn(x_val)
assert out_val.shape == exp_res.shape assert out_val.shape == exp_res.shape
...@@ -1216,7 +1232,7 @@ class TestFusion: ...@@ -1216,7 +1232,7 @@ class TestFusion:
# `Elemwise`s with more than one client shouldn't be rewritten # `Elemwise`s with more than one client shouldn't be rewritten
x = tensor(dtype="floatX", shape=(None, None, None), name="x") x = tensor(dtype="floatX", shape=(None, None, None), name="x")
exp_x = exp(x) exp_x = exp(x)
out = exp_x.sum(axis=axis) + exp(x) out = careduce_op(exp_x, axis=axis) + exp(x)
out_fn = function([x], out, mode=mode) out_fn = function([x], out, mode=mode)
out_nodes = out_fn.maker.fgraph.toposort() out_nodes = out_fn.maker.fgraph.toposort()
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论