Commit c8aa0615, authored by Pascal Lamblin

Merge pull request #3974 from nouiz/faster_opt

Faster optimisation
...@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -3344,7 +3344,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
########################## ##########################
def extract_constant(x, elemwise=True): def extract_constant(x, elemwise=True, only_process_constants=False):
""" """
This function is basically a call to tensor.get_scalar_constant_value. This function is basically a call to tensor.get_scalar_constant_value.
...@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True): ...@@ -3356,7 +3356,9 @@ def extract_constant(x, elemwise=True):
""" """
try: try:
x = get_scalar_constant_value(x, elemwise=elemwise) x = get_scalar_constant_value(x,
elemwise,
only_process_constants)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if ((isinstance(x, scal.ScalarVariable) or if ((isinstance(x, scal.ScalarVariable) or
......
...@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node): ...@@ -4665,23 +4665,23 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [node.inputs[0]] return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i] # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
T.extract_constant(node.inputs[0]) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [node.inputs[1]] return [node.inputs[1]]
...@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node): ...@@ -4689,11 +4689,11 @@ def local_useless_elemwise_comparison(node):
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \ T.extract_constant(node.inputs[0], only_process_constants=True) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
...@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4705,7 +4705,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
...@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4715,7 +4715,7 @@ def local_useless_elemwise_comparison(node):
isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \ isinstance(node.inputs[0].owner.op.scalar_op, scalar.Add) and \
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1], only_process_constants=True) == 0:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return return
...@@ -5826,38 +5826,6 @@ register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True) ...@@ -5826,38 +5826,6 @@ register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True) register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
def _is_1(expr):
    """
    Check whether `expr` is a scalar constant approximately equal to 1.

    Returns
    -------
    bool
        True iff `get_scalar_constant_value` can extract a value from
        `expr` and `numpy.allclose` considers that value close to 1.
        False when `expr` is not a scalar constant.

    """
    try:
        value = get_scalar_constant_value(expr)
    except NotScalarConstantError:
        # `expr` is not a scalar constant, so it cannot match.
        return False
    return numpy.allclose(value, 1)
def _is_minus1(expr):
    """
    Check whether `expr` is a scalar constant approximately equal to -1.

    Returns
    -------
    bool
        True iff `get_scalar_constant_value` can extract a value from
        `expr` and `numpy.allclose` considers that value close to -1.
        False when `expr` is not a scalar constant.

    """
    try:
        value = get_scalar_constant_value(expr)
    except NotScalarConstantError:
        # `expr` is not a scalar constant, so it cannot match.
        return False
    return numpy.allclose(value, -1)
def get_clients(node): def get_clients(node):
""" """
Used by erf/erfc opt to track less frequent op. Used by erf/erfc opt to track less frequent op.
...@@ -5881,7 +5849,7 @@ def get_clients2(node): ...@@ -5881,7 +5849,7 @@ def get_clients2(node):
# 1+erf(x)=>erfc(-x) # 1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add, local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1), 1,
(T.erf, 'x')), (T.erf, 'x')),
(T.erfc, (T.neg, 'x')), (T.erfc, (T.neg, 'x')),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5894,7 +5862,7 @@ register_specialize(local_one_plus_erf) ...@@ -5894,7 +5862,7 @@ register_specialize(local_one_plus_erf)
# 1-erf(x)=>erfc(x) # 1-erf(x)=>erfc(x)
local_one_minus_erf = gof.PatternSub((T.sub, local_one_minus_erf = gof.PatternSub((T.sub,
dict(pattern='y', constraint=_is_1), 1,
(T.erf, 'x')), (T.erf, 'x')),
(T.erfc, 'x'), (T.erfc, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5916,7 +5884,7 @@ register_specialize(local_one_minus_erf2) ...@@ -5916,7 +5884,7 @@ register_specialize(local_one_minus_erf2)
# 1+(-erf(x))=>erfc(x) This is a different graph than the previous one, as # 1+(-erf(x))=>erfc(x) This is a different graph than the previous one, as
# canonicalization doesn't work completely # canonicalization doesn't work completely
local_one_plus_neg_erf = gof.PatternSub((T.add, local_one_plus_neg_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1), 1,
(T.neg, (T.erf, 'x'))), (T.neg, (T.erf, 'x'))),
(T.erfc, 'x'), (T.erfc, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5930,7 +5898,7 @@ register_specialize(local_one_plus_neg_erf) ...@@ -5930,7 +5898,7 @@ register_specialize(local_one_plus_neg_erf)
# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize # (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument. # will put the -1 as the first argument.
local_erf_minus_one = gof.PatternSub((T.add, local_erf_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), -1,
(T.erf, 'x')), (T.erf, 'x')),
(T.neg, (T.erfc, 'x')), (T.neg, (T.erfc, 'x')),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5943,7 +5911,7 @@ register_specialize(local_erf_minus_one) ...@@ -5943,7 +5911,7 @@ register_specialize(local_erf_minus_one)
# 1-erfc(x) => erf(x) # 1-erfc(x) => erf(x)
local_one_minus_erfc = gof.PatternSub((T.sub, local_one_minus_erfc = gof.PatternSub((T.sub,
dict(pattern='y', constraint=_is_1), 1,
(T.erfc, 'x')), (T.erfc, 'x')),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5981,7 +5949,7 @@ register_specialize(local_one_minus_erfc3) ...@@ -5981,7 +5949,7 @@ register_specialize(local_one_minus_erfc3)
# 1+(-erfc(x)) => erf(x) This is a different graph than the previous one, as # 1+(-erfc(x)) => erf(x) This is a different graph than the previous one, as
# canonicalization doesn't work completely # canonicalization doesn't work completely
local_one_add_neg_erfc = gof.PatternSub((T.add, local_one_add_neg_erfc = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_1), 1,
(T.neg, (T.erfc, 'x'))), (T.neg, (T.erfc, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -5995,7 +5963,7 @@ register_specialize(local_one_add_neg_erfc) ...@@ -5995,7 +5963,7 @@ register_specialize(local_one_add_neg_erfc)
# (-1)+erfc(-x)=>erf(x) # (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = gof.PatternSub((T.add, local_erf_neg_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), -1,
(T.erfc, (T.neg, 'x'))), (T.erfc, (T.neg, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -6008,7 +5976,7 @@ register_specialize(local_erf_neg_minus_one) ...@@ -6008,7 +5976,7 @@ register_specialize(local_erf_neg_minus_one)
# (-1)+erfc(-1*x)=>erf(x) # (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = gof.PatternSub((T.add, local_erf_neg_minus_one2 = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), -1,
(T.erfc, (T.mul, -1, 'x'))), (T.erfc, (T.mul, -1, 'x'))),
(T.erf, 'x'), (T.erf, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论