提交 c8aa0615 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3974 from nouiz/faster_opt

Faster optimisation
......@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
##########################
def extract_constant(x, elemwise=True):
def extract_constant(x, elemwise=True, only_process_constants=False):
"""
This function is basically a call to tensor.get_scalar_constant_value.
......@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True):
"""
try:
x = get_scalar_constant_value(x, elemwise=elemwise)
x = get_scalar_constant_value(x,
elemwise,
only_process_constants)
except NotScalarConstantError:
pass
if ((isinstance(x, scal.ScalarVariable) or
......
......@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \
T.extract_constant(node.inputs[0]) == 0 and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [node.inputs[1]]
......@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.Minimum) and \
node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \
T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
......@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
......@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0:
T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return
......@@ -5826,38 +5826,6 @@ register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
def _is_1(expr):
"""
Returns
-------
bool
True iff expr is a constant close to 1.
"""
try:
v = get_scalar_constant_value(expr)
return numpy.allclose(v, 1)
except NotScalarConstantError:
return False
def _is_minus1(expr):
"""
Returns
-------
bool
True iff expr is a constant close to -1.
"""
try:
v = get_scalar_constant_value(expr)
return numpy.allclose(v, -1)
except NotScalarConstantError:
return False
def get_clients(node):
"""
Used by erf/erfc opt to track less frequent op.
......@@ -5881,7 +5849,7 @@ def get_clients2(node):
# 1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1),
1,
(T.erf, 'x')),
(T.erfc, (T.neg, 'x')),
allow_multiple_clients=True,
......@@ -5894,7 +5862,7 @@ register_specialize(local_one_plus_erf)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = gof.PatternSub((T.sub,
dict(pattern='y', constraint=_is_1),
1,
(T.erf, 'x')),
(T.erfc, 'x'),
allow_multiple_clients=True,
......@@ -5916,7 +5884,7 @@ register_specialize(local_one_minus_erf2)
# 1+(-erf(x))=>erfc(x) This is a different graph then the previous as
# the canonicalize don't work completly
local_one_plus_neg_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1),
1,
(T.neg, (T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients=True,
......@@ -5930,7 +5898,7 @@ register_specialize(local_one_plus_neg_erf)
# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument.
local_erf_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
-1,
(T.erf, 'x')),
(T.neg, (T.erfc, 'x')),
allow_multiple_clients=True,
......@@ -5943,7 +5911,7 @@ register_specialize(local_erf_minus_one)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = gof.PatternSub((T.sub,
dict(pattern='y', constraint=_is_1),
1,
(T.erfc, 'x')),
(T.erf, 'x'),
allow_multiple_clients=True,
......@@ -5981,7 +5949,7 @@ register_specialize(local_one_minus_erfc3)
# 1+(-erfc(x)) => erf(x) This is a different graph then the previous as
# the canonicalize don't work completly
local_one_add_neg_erfc = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1),
1,
(T.neg, (T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
......@@ -5995,7 +5963,7 @@ register_specialize(local_one_add_neg_erfc)
# (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
-1,
(T.erfc, (T.neg, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
......@@ -6008,7 +5976,7 @@ register_specialize(local_erf_neg_minus_one)
# (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
-1,
(T.erfc, (T.mul, -1, 'x'))),
(T.erf, 'x'),
allow_multiple_clients=True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论