提交 8d4c5f4d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove superfluous test value warning in Elemwise fusion rewrite

上级 584496dc
...@@ -16,6 +16,7 @@ from aesara import compile ...@@ -16,6 +16,7 @@ from aesara import compile
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply,
Constant, Constant,
Variable, Variable,
ancestors, ancestors,
...@@ -24,7 +25,7 @@ from aesara.graph.basic import ( ...@@ -24,7 +25,7 @@ from aesara.graph.basic import (
) )
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import ( from aesara.graph.opt import (
GlobalOptimizer, GlobalOptimizer,
OpRemove, OpRemove,
...@@ -3003,7 +3004,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3003,7 +3004,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
fused = False fused = False
for i in node.inputs: for i in node.inputs:
do_fusion = False scalar_node: Optional[Apply] = None
# Will store inputs of the fused node that are not currently inputs # Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs). # of the node we want to create (to avoid duplicating inputs).
tmp_input = [] tmp_input = []
...@@ -3034,36 +3035,45 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3034,36 +3035,45 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp = aes.get_scalar_type(ii.type.dtype).make_variable() tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
try: try:
tv = get_test_value(ii) tv = get_test_value(ii)
if tv.size > 0: # Sometimes the original inputs have
tmp.tag.test_value = tv.flatten()[0] # zero-valued shapes in some dimensions, which
else: # implies that this whole scalar thing doesn't
_logger.warning( # make sense (i.e. we're asking for the scalar
"Cannot construct a scalar test value" # value of an entry in a zero-dimensional
" from a test value with no size: {}".format(ii) # array).
) # This will eventually lead to an error in the
except TestValueError: # `compute_test_value` call below when/if
# `config.compute_test_value_opt` is enabled
# (for debugging, more or less)
tmp.tag.test_value = tv.item()
except (TestValueError, ValueError):
pass pass
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True) # Use the `Op.make_node` interface in case `Op.__call__`
# has been customized
scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
if config.compute_test_value_opt != "off":
# This is required because `Op.make_node` won't do it
compute_test_value(scalar_node)
# If the scalar_op doesn't have a C implementation, we skip # If the scalar_op doesn't have a C implementation, we skip
# its fusion to allow fusion of the other ops # its fusion to allow fusion of the other ops
i.owner.op.scalar_op.c_code( i.owner.op.scalar_op.c_code(
s_op[0].owner, scalar_node,
"test_presence_of_c_code", "test_presence_of_c_code",
["x" for x in i.owner.inputs], ["x" for x in i.owner.inputs],
["z" for z in i.owner.outputs], ["z" for z in i.owner.outputs],
{"fail": "%(fail)s"}, {"fail": "%(fail)s"},
) )
do_fusion = True
except (NotImplementedError, MethodNotDefined): except (NotImplementedError, MethodNotDefined):
_logger.warning( _logger.warning(
( (
...@@ -3073,7 +3083,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3073,7 +3083,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"loop fusion." "loop fusion."
) )
) )
do_fusion = False scalar_node = None
# Compute the number of inputs in case we fuse this input. # Compute the number of inputs in case we fuse this input.
# We subtract 1 because we replace the existing input with the new # We subtract 1 because we replace the existing input with the new
...@@ -3089,12 +3099,12 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3089,12 +3099,12 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
if x in node.inputs: if x in node.inputs:
new_nb_input_ -= 1 new_nb_input_ -= 1
if do_fusion and (new_nb_input_ <= max_nb_input): if scalar_node and (new_nb_input_ <= max_nb_input):
fused = True fused = True
new_nb_input = new_nb_input_ new_nb_input = new_nb_input_
inputs.extend(tmp_input) inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar) s_inputs.extend(tmp_scalar)
s_g.extend(s_op) s_g.extend(scalar_node.outputs)
else: else:
# We must support the case where the same variable appears many # We must support the case where the same variable appears many
# times within the inputs # times within the inputs
...@@ -3102,12 +3112,13 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -3102,12 +3112,13 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = aes.get_scalar_type(i.type.dtype).make_variable() s = aes.get_scalar_type(i.type.dtype).make_variable()
if config.compute_test_value_opt != "off":
try: try:
if config.compute_test_value != "off":
v = get_test_value(i) v = get_test_value(i)
if v.size > 0: # See the zero-dimensional test value situation
s.tag.test_value = v.flatten()[0] # described above.
except TestValueError: s.tag.test_value = v.item()
except (TestValueError, ValueError):
pass pass
inputs.append(i) inputs.append(i)
...@@ -3157,7 +3168,8 @@ your code will run correctly, but may be slower.""" ...@@ -3157,7 +3168,8 @@ your code will run correctly, but may be slower."""
if len(new_node.inputs) > max_nb_input: if len(new_node.inputs) > max_nb_input:
_logger.warning( _logger.warning(
"loop fusion failed because Op would exceed" " kernel argument limit." "Loop fusion failed because the resulting node "
"would exceed the kernel argument limit."
) )
return False return False
......
import contextlib
import copy import copy
import numpy as np import numpy as np
...@@ -22,6 +23,7 @@ from aesara.graph.type import Type ...@@ -22,6 +23,7 @@ from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
from aesara.raise_op import Assert, CheckAndRaise from aesara.raise_op import Assert, CheckAndRaise
from aesara.scalar.basic import Composite
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -1152,6 +1154,54 @@ class TestFusion: ...@@ -1152,6 +1154,54 @@ class TestFusion:
for n in f.maker.fgraph.toposort() for n in f.maker.fgraph.toposort()
) )
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
"""
opts = OptimizationQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
],
exclude=["cxx_only", "BlasOpt"],
)
mode = Mode(self.mode.linker, opts)
x, y, z = dmatrices("xyz")
x.tag.test_value = test_value
y.tag.test_value = test_value
z.tag.test_value = test_value
if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()
with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise"
):
out = x * y + z
with cm:
f = function([x, y, z], out, mode=mode)
if test_value.size != 0:
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)
class TimesN(aes.basic.UnaryScalarOp): class TimesN(aes.basic.UnaryScalarOp):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论