提交 2c4a3e7b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Tag rewrites that make shape assumptions

上级 5db0d833
...@@ -682,25 +682,7 @@ def add_traceback_configvars(): ...@@ -682,25 +682,7 @@ def add_traceback_configvars():
def add_experimental_configvars(): def add_experimental_configvars():
config.add( return
"experimental__local_alloc_elemwise",
"DEPRECATED: If True, enable the experimental"
" optimization local_alloc_elemwise."
" Generates error if not True. Use"
" optimizer_excluding=local_alloc_elemwise"
" to disable.",
BoolParam(True),
in_c_key=False,
)
# False could make the graph faster but not as safe.
config.add(
"experimental__local_alloc_elemwise_assert",
"When the local_alloc_elemwise is applied, add"
" an assert to highlight shape errors.",
BoolParam(True),
in_c_key=False,
)
def add_error_and_warning_configvars(): def add_error_and_warning_configvars():
......
...@@ -256,7 +256,7 @@ def local_scalar_tensor_scalar(fgraph, node): ...@@ -256,7 +256,7 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s] return [s]
@register_specialize("local_alloc_elemwise") @register_specialize("shape_unsafe")
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_elemwise_alloc(fgraph, node): def local_elemwise_alloc(fgraph, node):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
...@@ -377,7 +377,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -377,7 +377,7 @@ def local_elemwise_alloc(fgraph, node):
return ret return ret
@register_canonicalize @register_canonicalize("shape_unsafe")
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_fill_sink(fgraph, node): def local_fill_sink(fgraph, node):
""" """
...@@ -428,8 +428,8 @@ def local_fill_sink(fgraph, node): ...@@ -428,8 +428,8 @@ def local_fill_sink(fgraph, node):
return replacements return replacements
@register_specialize @register_specialize("shape_unsafe")
@register_stabilize @register_stabilize("shape_unsafe")
@node_rewriter([fill]) @node_rewriter([fill])
def local_fill_to_alloc(fgraph, node): def local_fill_to_alloc(fgraph, node):
r"""Remove `fill`\s or replace them with `Alloc`\s. r"""Remove `fill`\s or replace them with `Alloc`\s.
...@@ -479,8 +479,8 @@ compile.optdb.register( ...@@ -479,8 +479,8 @@ compile.optdb.register(
) )
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile", "shape_unsafe")
@register_useless @register_useless("shape_unsafe")
@node_rewriter([fill]) @node_rewriter([fill])
def local_useless_fill(fgraph, node): def local_useless_fill(fgraph, node):
"""fill(s,v) -> v """fill(s,v) -> v
...@@ -500,10 +500,10 @@ def local_useless_fill(fgraph, node): ...@@ -500,10 +500,10 @@ def local_useless_fill(fgraph, node):
return [v] return [v]
@register_specialize @register_specialize("shape_unsafe")
@register_stabilize @register_stabilize("shape_unsafe")
@register_canonicalize @register_canonicalize("shape_unsafe")
@register_useless @register_useless("shape_unsafe")
@node_rewriter([Alloc]) @node_rewriter([Alloc])
def local_useless_alloc(fgraph, node): def local_useless_alloc(fgraph, node):
""" """
......
...@@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None): ...@@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
local_mul_canonizer = AlgebraicCanonizer( local_mul_canonizer = AlgebraicCanonizer(
mul, true_div, reciprocal, mul_calculate, False mul, true_div, reciprocal, mul_calculate, False
) )
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canonizer")
@register_canonicalize @register_canonicalize
...@@ -2493,7 +2493,7 @@ add_canonizer = in2out( ...@@ -2493,7 +2493,7 @@ add_canonizer = in2out(
) )
register_canonicalize(local_add_canonizer, name="local_add_canonizer") register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0):
......
...@@ -1933,3 +1933,17 @@ class TestLocalElemwiseAlloc: ...@@ -1933,3 +1933,17 @@ class TestLocalElemwiseAlloc:
x_val = np.random.random((1, 5)).astype(self.dtype) x_val = np.random.random((1, 5)).astype(self.dtype)
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
assert np.array_equal(func(y_val, x_val), exp_res) assert np.array_equal(func(y_val, x_val), exp_res)
def test_shape_unsafe_tag():
mode = get_mode("FAST_RUN")
x = vector("x")
y = vector("y")
out = x * y / y
fn = function([x, y], out, mode=mode)
np.testing.assert_equal(fn([0, 1], [2, 3, 4]), [0, 1])
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
with pytest.raises(ValueError):
fn([0, 1], [2, 3, 4]), [0, 1]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论