提交 b38ffe57 authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -1353,42 +1353,49 @@ def local_mul_specialize(node): ...@@ -1353,42 +1353,49 @@ def local_mul_specialize(node):
neg = False neg = False
new_inputs = [] new_inputs = []
for input in node.inputs: for input in node.inputs:
# remove any neg arguments
while input.owner and input.owner.op == T.neg:
neg ^= True
input = input.owner.inputs[0]
# remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input) y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0): if N.all(y == 1.0):
continue continue
elif N.all(y == -1.0): elif N.all(y == -1.0):
neg ^= True #toggles neg ^= True #toggles
elif N.all(y == 0.0): elif N.all(y == 0.0):
return fill_chain(input) return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if new_inputs != node.inputs:
if neg: if new_inputs:
newval = -y.flatten()[0] if len(new_inputs) == 1:
if neg:
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
else: else:
newval = y.flatten()[0] if neg:
#newval = -y.flatten()[0] if neg else y.flatten()[0] msg = -T.mul(*new_inputs)
return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype, else:
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))) msg = T.mul(*new_inputs)
if len(new_inputs) == 1: return [T.alloc(T.cast(msg, node.outputs[0].dtype),
if neg: *node.env.shape_feature.shape_of[node.outputs[0]])]
msg = -new_inputs[0]
else:
msg = new_inputs[0]
return fill_chain(msg)
# return fill_chain(-new_inputs[0] if neg else new_inputs[0])
else: else:
if neg: if neg:
msg = -T.mul(*new_inputs) # return output's worth of -1
return [T.alloc(numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
msg = T.mul(*new_inputs) # return output's worth of 1
return [T.alloc(numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
#return fill_chain(-T.mul(*new_inputs) if neg else \
# T.mul(*new_inputs))
else:
return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
@gof.local_optimizer([T.add]) @gof.local_optimizer([T.add])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论