提交 8d4c5f4d 作者: Brandon T. Willard 提交者: Brandon T. Willard

Remove superfluous test value warning in Elemwise fusion rewrite

上级 584496dc
......@@ -16,6 +16,7 @@ from aesara import compile
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph.basic import (
Apply,
Constant,
Variable,
ancestors,
......@@ -24,7 +25,7 @@ from aesara.graph.basic import (
)
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import (
GlobalOptimizer,
OpRemove,
......@@ -3003,7 +3004,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
fused = False
for i in node.inputs:
do_fusion = False
scalar_node: Optional[Apply] = None
# Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs).
tmp_input = []
......@@ -3034,36 +3035,45 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
try:
tv = get_test_value(ii)
if tv.size > 0:
tmp.tag.test_value = tv.flatten()[0]
else:
_logger.warning(
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except TestValueError:
# Sometimes the original inputs have
# zero-valued shapes in some dimensions, which
# implies that this whole scalar thing doesn't
# make sense (i.e. we're asking for the scalar
# value of an entry in a zero-dimensional
# array).
# This will eventually lead to an error in the
# `compute_test_value` call below when/if
# `config.compute_test_value_opt` is enabled
# (for debugging, more or less)
tmp.tag.test_value = tv.item()
except (TestValueError, ValueError):
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
# Use the `Op.make_node` interface in case `Op.__call__`
# has been customized
scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
if config.compute_test_value_opt != "off":
# This is required because `Op.make_node` won't do it
compute_test_value(scalar_node)
# If the scalar_op doesn't have a C implementation, we skip
# its fusion to allow fusion of the other ops
i.owner.op.scalar_op.c_code(
s_op[0].owner,
scalar_node,
"test_presence_of_c_code",
["x" for x in i.owner.inputs],
["z" for z in i.owner.outputs],
{"fail": "%(fail)s"},
)
do_fusion = True
except (NotImplementedError, MethodNotDefined):
_logger.warning(
(
......@@ -3073,7 +3083,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"loop fusion."
)
)
do_fusion = False
scalar_node = None
# Compute the number of inputs in case we fuse this input.
# We subtract 1 because we replace the existing input with the new
......@@ -3089,12 +3099,12 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
if x in node.inputs:
new_nb_input_ -= 1
if do_fusion and (new_nb_input_ <= max_nb_input):
if scalar_node and (new_nb_input_ <= max_nb_input):
fused = True
new_nb_input = new_nb_input_
inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar)
s_g.extend(s_op)
s_g.extend(scalar_node.outputs)
else:
# We must support the case where the same variable appears many
# times within the inputs
......@@ -3102,13 +3112,14 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
s = s_inputs[inputs.index(i)]
else:
s = aes.get_scalar_type(i.type.dtype).make_variable()
try:
if config.compute_test_value != "off":
if config.compute_test_value_opt != "off":
try:
v = get_test_value(i)
if v.size > 0:
s.tag.test_value = v.flatten()[0]
except TestValueError:
pass
# See the zero-dimensional test value situation
# described above.
s.tag.test_value = v.item()
except (TestValueError, ValueError):
pass
inputs.append(i)
s_inputs.append(s)
......@@ -3157,7 +3168,8 @@ your code will run correctly, but may be slower."""
if len(new_node.inputs) > max_nb_input:
_logger.warning(
"loop fusion failed because Op would exceed" " kernel argument limit."
"Loop fusion failed because the resulting node "
"would exceed the kernel argument limit."
)
return False
......
import contextlib
import copy
import numpy as np
......@@ -22,6 +23,7 @@ from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.raise_op import Assert, CheckAndRaise
from aesara.scalar.basic import Composite
from aesara.tensor.basic import (
Alloc,
Join,
......@@ -1152,6 +1154,54 @@ class TestFusion:
for n in f.maker.fgraph.toposort()
)
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value):
    """Check how `local_elemwise_fusion_op` deals with test values.

    The test values in question are the ones attached to the scalar
    variables that the fusion rewrite builds when it checks for C
    implementations.

    Two cases are exercised through the ``test_value`` parameter:

    * a non-empty ``(1, 1)`` array: compilation must succeed and the graph
      must be fused into a single node whose scalar `Op` is a `Composite`;
    * an empty array: compiling with ``compute_test_value_opt="raise"`` is
      expected to raise a `ValueError`.

    """
    opts = OptimizationQuery(
        include=[
            "local_elemwise_fusion",
            "composite_elemwise_fusion",
            "canonicalize",
        ],
        exclude=["cxx_only", "BlasOpt"],
    )
    mode = Mode(self.mode.linker, opts)

    x, y, z = dmatrices("xyz")
    for var in (x, y, z):
        var.tag.test_value = test_value

    # Only the empty test value is expected to fail; otherwise use a
    # context manager that does nothing.
    if test_value.size == 0:
        cm = pytest.raises(ValueError)
    else:
        cm = contextlib.nullcontext()

    # Both flags are set to "raise" so that test-value problems surface as
    # exceptions while the graph is built and while it is compiled.
    with config.change_flags(
        compute_test_value="raise", compute_test_value_opt="raise"
    ):
        out = x * y + z
        with cm:
            f = function([x, y, z], out, mode=mode)

    if test_value.size != 0:
        fused_out = f.maker.fgraph.outputs[0]

        # Confirm that the fusion happened
        assert isinstance(fused_out.owner.op.scalar_op, Composite)
        assert len(f.maker.fgraph.toposort()) == 1

        # The fused node must still take exactly the three original inputs
        assert len(fused_out.owner.inputs) == 3

        # 1.0 * 1.0 + 1.0 == 2.0
        assert np.array_equal(fused_out.tag.test_value, np.c_[[2.0]])
class TimesN(aes.basic.UnaryScalarOp):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论