提交 2c4a3e7b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Tag rewrites that make shape assumptions

上级 5db0d833
......@@ -682,25 +682,7 @@ def add_traceback_configvars():
def add_experimental_configvars():
config.add(
"experimental__local_alloc_elemwise",
"DEPRECATED: If True, enable the experimental"
" optimization local_alloc_elemwise."
" Generates error if not True. Use"
" optimizer_excluding=local_alloc_elemwise"
" to disable.",
BoolParam(True),
in_c_key=False,
)
# False could make the graph faster but not as safe.
config.add(
"experimental__local_alloc_elemwise_assert",
"When the local_alloc_elemwise is applied, add"
" an assert to highlight shape errors.",
BoolParam(True),
in_c_key=False,
)
return
def add_error_and_warning_configvars():
......
......@@ -256,7 +256,7 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s]
@register_specialize("local_alloc_elemwise")
@register_specialize("shape_unsafe")
@node_rewriter([Elemwise])
def local_elemwise_alloc(fgraph, node):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
......@@ -377,7 +377,7 @@ def local_elemwise_alloc(fgraph, node):
return ret
@register_canonicalize
@register_canonicalize("shape_unsafe")
@node_rewriter([Elemwise])
def local_fill_sink(fgraph, node):
"""
......@@ -428,8 +428,8 @@ def local_fill_sink(fgraph, node):
return replacements
@register_specialize
@register_stabilize
@register_specialize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([fill])
def local_fill_to_alloc(fgraph, node):
r"""Remove `fill`\s or replace them with `Alloc`\s.
......@@ -479,8 +479,8 @@ compile.optdb.register(
)
@register_canonicalize("fast_compile")
@register_useless
@register_canonicalize("fast_compile", "shape_unsafe")
@register_useless("shape_unsafe")
@node_rewriter([fill])
def local_useless_fill(fgraph, node):
"""fill(s,v) -> v
......@@ -500,10 +500,10 @@ def local_useless_fill(fgraph, node):
return [v]
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@register_specialize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_canonicalize("shape_unsafe")
@register_useless("shape_unsafe")
@node_rewriter([Alloc])
def local_useless_alloc(fgraph, node):
"""
......
......@@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
local_mul_canonizer = AlgebraicCanonizer(
mul, true_div, reciprocal, mul_calculate, False
)
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canonizer")
@register_canonicalize
......@@ -2493,7 +2493,7 @@ add_canonizer = in2out(
)
register_canonicalize(local_add_canonizer, name="local_add_canonizer")
register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0):
......
......@@ -1933,3 +1933,17 @@ class TestLocalElemwiseAlloc:
x_val = np.random.random((1, 5)).astype(self.dtype)
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
assert np.array_equal(func(y_val, x_val), exp_res)
def test_shape_unsafe_tag():
mode = get_mode("FAST_RUN")
x = vector("x")
y = vector("y")
out = x * y / y
fn = function([x, y], out, mode=mode)
np.testing.assert_equal(fn([0, 1], [2, 3, 4]), [0, 1])
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
with pytest.raises(ValueError):
fn([0, 1], [2, 3, 4]), [0, 1]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论