提交 b8cd1229 authored 作者: James Bergstra's avatar James Bergstra

rewrite of local_mul_specialize

上级 6080bef5
......@@ -1353,42 +1353,49 @@ def local_mul_specialize(node):
neg = False
new_inputs = []
for input in node.inputs:
# remove any neg arguments
while input.owner and input.owner.op == T.neg:
neg ^= True
input = input.owner.inputs[0]
# remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0):
continue
elif N.all(y == -1.0):
neg ^= True #toggles
elif N.all(y == 0.0):
return fill_chain(input)
return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else:
new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0:
if neg:
newval = -y.flatten()[0]
if new_inputs != node.inputs:
if new_inputs:
if len(new_inputs) == 1:
if neg:
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
else:
newval = y.flatten()[0]
#newval = -y.flatten()[0] if neg else y.flatten()[0]
return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval)))
if neg:
msg = -T.mul(*new_inputs)
else:
msg = T.mul(*new_inputs)
if len(new_inputs) == 1:
if neg:
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
# return fill_chain(-new_inputs[0] if neg else new_inputs[0])
return [T.alloc(T.cast(msg, node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else:
if neg:
msg = -T.mul(*new_inputs)
# return output's worth of -1
return [T.alloc(numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else:
msg = T.mul(*new_inputs)
# return output's worth of 1
return [T.alloc(numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
#return fill_chain(-T.mul(*new_inputs) if neg else \
# T.mul(*new_inputs))
else:
return False
register_specialize(local_mul_specialize)
@gof.local_optimizer([T.add])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论