Commit e8e4d2ef authored by James Bergstra

adding local_add_specialize that removes any +0 from the graph.

Parent eda8fc3c
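For context, here is a minimal usage sketch of the effect this rewrite is meant to have, assuming a Theano install where the specialize pass (and hence this optimizer) runs as part of the default compilation mode; the variable names are illustrative only and are not part of the commit:

import theano
import theano.tensor as T

x = T.vector('x')
y = x + 0          # the `+ 0` is the pattern local_add_specialize targets

f = theano.function([x], y)
# After specialization the compiled graph is expected to reduce to `x`
# (possibly wrapped in fill/DimShuffle nodes to preserve broadcasting).
theano.printing.debugprint(f)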
@@ -850,7 +850,47 @@ def local_mul_specialize(node):
return False
register_specialize(local_mul_specialize)
@gof.local_optimizer([T.add])
def local_add_specialize(node):
    def fill_chain(v):
        return _fill_chain(v, node.inputs)

    def get_constant_through_fills_and_subtensors(v):
        # Look through fill and DimShuffle wrappers for an underlying constant.
        # Anything else falls through and returns None, which the zero test
        # below treats as "not a constant zero".
        if v.owner is not None:
            if v.owner.op == T.fill:
                assert len(v.owner.inputs) == 2
                return get_constant_through_fills_and_subtensors(v.owner.inputs[1])
            if isinstance(v.owner.op, T.DimShuffle):
                assert len(v.owner.inputs) == 1
                return get_constant_through_fills_and_subtensors(v.owner.inputs[0])
        elif hasattr(v, 'data'):
            return v.data
        else:
            return v

    # Here we are past the point of canonicalization, so we don't want to put
    # in unnecessary fills.
    if node.op == T.add:
        new_inputs = []
        for input in node.inputs:
            y = get_constant_through_fills_and_subtensors(input)
            if N.all(y == 0.0):
                # This addend is a constant zero: drop it.
                continue
            else:
                new_inputs.append(input)
        if len(new_inputs) < len(node.inputs):
            if len(new_inputs) == 0:
                # We got rid of the entire expression!
                return fill_chain(T.TensorConstant(
                    T.TensorType(dtype=node.outputs[0].type.dtype,
                                 broadcastable=[True] * node.outputs[0].ndim),
                    N.asarray(0)))
            if len(new_inputs) == 1:
                return fill_chain(new_inputs[0])
            else:
                return fill_chain(T.add(*new_inputs))
    else:
        return False
register_specialize(local_add_specialize)
# neg_to_mul = out2in(gof.LocalOptGroup(local_neg_to_mul))
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
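The fill_chain wrapping at each return exists because dropping a zero addend is not, by itself, shape-preserving; here is a small illustration (plain NumPy, not part of the commit) of the broadcasting behaviour the rewrite has to keep intact:

import numpy as np

x = np.float64(3.0)      # stand-in for a scalar input to the add
zeros = np.zeros(5)      # stand-in for the constant-zero addend being dropped

print((x + zeros).shape)    # (5,)  -- the add broadcasts x up to length 5
print(np.asarray(x).shape)  # ()    -- x alone has a different shape
# Hence the surviving value is re-wrapped with fill(...) against the original
# inputs, so the optimized output keeps the same broadcast pattern.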