提交 ce1eeab9 authored 作者: abergeron's avatar abergeron

Merge pull request #1797 from nouiz/fast_opt

Make the slow scan test fast!
...@@ -1077,6 +1077,7 @@ class FunctionMaker(object): ...@@ -1077,6 +1077,7 @@ class FunctionMaker(object):
self.mode = mode self.mode = mode
self.accept_inplace = accept_inplace self.accept_inplace = accept_inplace
self.function_builder = function_builder self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling
self.required = [(i.value is None) for i in self.inputs] self.required = [(i.value is None) for i in self.inputs]
self.refeed = [ self.refeed = [
...@@ -1215,6 +1216,7 @@ def _pickle_FunctionMaker(self): ...@@ -1215,6 +1216,7 @@ def _pickle_FunctionMaker(self):
accept_inplace=self.accept_inplace, accept_inplace=self.accept_inplace,
function_builder=self.function_builder, function_builder=self.function_builder,
profile=self.profile, profile=self.profile,
on_unused_input=self.on_unused_input,
) )
return (_constructor_FunctionMaker, (kwargs,)) return (_constructor_FunctionMaker, (kwargs,))
......
差异被折叠。
...@@ -1581,7 +1581,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1581,7 +1581,7 @@ def local_upcast_elemwise_constant_inputs(node):
else: else:
try: try:
# works only for scalars # works only for scalars
cval_i = get_scalar_constant_value(i) cval_i = get_scalar_constant_value(i, elemwise=False)
if all(i.broadcastable): if all(i.broadcastable):
new_inputs.append(T.shape_padleft( new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype), T.cast(cval_i, output_dtype),
...@@ -2327,7 +2327,7 @@ def local_remove_switch_const_cond(node): ...@@ -2327,7 +2327,7 @@ def local_remove_switch_const_cond(node):
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0]) cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0: if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0: if cond == 0:
out = node.inputs[2] out = node.inputs[2]
...@@ -2377,7 +2377,8 @@ def local_mul_switch_sink(node): ...@@ -2377,7 +2377,8 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
try: try:
if get_scalar_constant_value(switch.inputs[1]) == 0.: if (isinstance(switch.inputs[0], Constant) and
get_scalar_constant_value(switch.inputs[1]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))] T.mul(*(listmul + [switch.inputs[2]])))]
...@@ -2387,7 +2388,8 @@ def local_mul_switch_sink(node): ...@@ -2387,7 +2388,8 @@ def local_mul_switch_sink(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if get_scalar_constant_value(switch.inputs[2]) == 0.: if (isinstance(switch.inputs[2], Constant) and
get_scalar_constant_value(switch.inputs[2]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)] T.mul(*(listmul + [switch.inputs[1]])), 0)]
...@@ -3784,7 +3786,7 @@ def local_abs_merge(node): ...@@ -3784,7 +3786,7 @@ def local_abs_merge(node):
for i in node.inputs: for i in node.inputs:
if i.owner and i.owner.op == T.abs_: if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
else: elif isinstance(i, Constant):
try: try:
const = get_scalar_constant_value(i) const = get_scalar_constant_value(i)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -3792,6 +3794,8 @@ def local_abs_merge(node): ...@@ -3792,6 +3794,8 @@ def local_abs_merge(node):
if not (const >= 0).all(): if not (const >= 0).all():
return False return False
inputs.append(i) inputs.append(i)
else:
return False
return [T.abs_(T.mul(*inputs))] return [T.abs_(T.mul(*inputs))]
if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in
node.inputs if i.owner]) == 2: node.inputs if i.owner]) == 2:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论