提交 ce1eeab9 authored 作者: abergeron's avatar abergeron

Merge pull request #1797 from nouiz/fast_opt

Make the slow scan test fast!
......@@ -1077,6 +1077,7 @@ class FunctionMaker(object):
self.mode = mode
self.accept_inplace = accept_inplace
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used only for the pickling
self.required = [(i.value is None) for i in self.inputs]
self.refeed = [
......@@ -1215,6 +1216,7 @@ def _pickle_FunctionMaker(self):
accept_inplace=self.accept_inplace,
function_builder=self.function_builder,
profile=self.profile,
on_unused_input=self.on_unused_input,
)
return (_constructor_FunctionMaker, (kwargs,))
......
差异被折叠。
......@@ -1581,7 +1581,7 @@ def local_upcast_elemwise_constant_inputs(node):
else:
try:
# works only for scalars
cval_i = get_scalar_constant_value(i)
cval_i = get_scalar_constant_value(i, elemwise=False)
if all(i.broadcastable):
new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype),
......@@ -2327,7 +2327,7 @@ def local_remove_switch_const_cond(node):
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0])
cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
out = node.inputs[2]
......@@ -2377,7 +2377,8 @@ def local_mul_switch_sink(node):
if i.owner and i.owner.op == T.switch:
switch = i.owner
try:
if get_scalar_constant_value(switch.inputs[1]) == 0.:
if (isinstance(switch.inputs[0], Constant) and
get_scalar_constant_value(switch.inputs[1]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
......@@ -2387,7 +2388,8 @@ def local_mul_switch_sink(node):
except NotScalarConstantError:
pass
try:
if get_scalar_constant_value(switch.inputs[2]) == 0.:
if (isinstance(switch.inputs[2], Constant) and
get_scalar_constant_value(switch.inputs[2]) == 0.):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
......@@ -3784,7 +3786,7 @@ def local_abs_merge(node):
for i in node.inputs:
if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0])
else:
elif isinstance(i, Constant):
try:
const = get_scalar_constant_value(i)
except NotScalarConstantError:
......@@ -3792,6 +3794,8 @@ def local_abs_merge(node):
if not (const >= 0).all():
return False
inputs.append(i)
else:
return False
return [T.abs_(T.mul(*inputs))]
if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in
node.inputs if i.owner]) == 2:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论