提交 64d621b4 authored 作者: Frederic's avatar Frederic

Make optimizations faster by calling get_scalar_constant_value() only on constants.

This last commit finishes bringing the slow scan test from 584s down to 66s: theano/scan_module/tests/test_scan.py:T_Scan.test_hessian_bug_grad_grad_two_scans. The changed optimizations — local_abs_merge, local_mul_switch_sink, local_upcast_elemwise_constant_inputs, local_remove_switch_const_cond — are called on nodes that aren't yet part of the graph, during the local_subtensor_merge optimization. They aren't supposed to work on nodes that aren't in the graph, but because they called get_scalar_constant_value() they traversed it very frequently! This pre_greedy_local_optimizer call is needed because otherwise it would introduce too much scrap that would only be optimized away later; so we pre-optimize the graph. Since we now do constant_folding in the same phases (canonicalize/stabilize/specialize) where those changed optimizations are used, they can work directly on Constant inputs. Other optimizations will make sure their inputs become constants when possible.
上级 231e51f0
...@@ -1580,6 +1580,8 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1580,6 +1580,8 @@ def local_upcast_elemwise_constant_inputs(node):
new_inputs.append(i) new_inputs.append(i)
else: else:
try: try:
if not isinstance(i, Constant):
raise NotScalarConstantError()
# works only for scalars # works only for scalars
cval_i = get_scalar_constant_value(i) cval_i = get_scalar_constant_value(i)
if all(i.broadcastable): if all(i.broadcastable):
...@@ -2326,7 +2328,8 @@ def local_remove_switch_const_cond(node): ...@@ -2326,7 +2328,8 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch) and
isinstance(node.inputs[0], Constant)):
cond = T.extract_constant(node.inputs[0]) cond = T.extract_constant(node.inputs[0])
if type(cond) is numpy.ndarray and cond.ndim == 0: if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0: if cond == 0:
...@@ -2377,7 +2380,8 @@ def local_mul_switch_sink(node): ...@@ -2377,7 +2380,8 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
try: try:
if get_scalar_constant_value(switch.inputs[1]) == 0.: if (isinstance(switch.inputs[0], Constant) and
get_scalar_constant_value(switch.inputs[1]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))] T.mul(*(listmul + [switch.inputs[2]])))]
...@@ -2387,7 +2391,8 @@ def local_mul_switch_sink(node): ...@@ -2387,7 +2391,8 @@ def local_mul_switch_sink(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if get_scalar_constant_value(switch.inputs[2]) == 0.: if (isinstance(switch.inputs[2], Constant) and
get_scalar_constant_value(switch.inputs[2]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)] T.mul(*(listmul + [switch.inputs[1]])), 0)]
...@@ -3784,7 +3789,7 @@ def local_abs_merge(node): ...@@ -3784,7 +3789,7 @@ def local_abs_merge(node):
for i in node.inputs: for i in node.inputs:
if i.owner and i.owner.op == T.abs_: if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
else: elif isinstance(i, Constant):
try: try:
const = get_scalar_constant_value(i) const = get_scalar_constant_value(i)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -3792,6 +3797,8 @@ def local_abs_merge(node): ...@@ -3792,6 +3797,8 @@ def local_abs_merge(node):
if not (const >= 0).all(): if not (const >= 0).all():
return False return False
inputs.append(i) inputs.append(i)
else:
return False
return [T.abs_(T.mul(*inputs))] return [T.abs_(T.mul(*inputs))]
if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in
node.inputs if i.owner]) == 2: node.inputs if i.owner]) == 2:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论