Commit e767d723 authored by James Bergstra

Cleaned up the logic in one of the optimizations in nnet.py

Parent 65e376b2
@@ -403,10 +403,15 @@ def local_softmax_with_bias(node):
             non_vectors = []
             for x_in in x.owner.inputs:
                 if list(x_in.type.broadcastable) == [True, False]:
-                    if x_in.owner and isinstance(x_in.owner.op, tensor.DimShuffle):
-                        assert len(x_in.owner.inputs)==1
+                    print isinstance(x_in.owner.op, tensor.DimShuffle)
+                    #since specialization comes relatively late in optimization,
+                    # we don't want to put in extra DimShuffles un-necessarily.
+                    if x_in.owner and isinstance(x_in.owner.op, tensor.DimShuffle)\
+                            and list(x_in.owner.inputs[0].type.broadcastable)==[False]:
+                        # cut out the DimShuffle that was broadcasting a vector
                         vectors.append(x_in.owner.inputs[0])
                     else:
+                        # insert an extra DimShuffle to correct the old one
                         vectors.append(tensor.DimShuffle((True, False), (1,))(x_in))
                 else:
                     non_vectors.append(x_in)
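In the new version, a row-shaped input (broadcastable pattern [True, False]) is only unwrapped when it was actually produced by a DimShuffle applied to a genuine 1-d vector; otherwise the optimization falls through to inserting a corrective DimShuffle. As a rough illustration of that branching logic (a minimal sketch, not Theano's real graph classes: the Variable, Apply, and DimShuffle stubs below are hypothetical stand-ins):

    # Minimal stand-ins for Theano's graph objects, for illustration only.
    class DimShuffle(object):
        """Stub op: e.g. reshapes a (False,) vector into a (True, False) row."""
        def __init__(self, input_broadcastable, new_order):
            self.input_broadcastable = input_broadcastable
            self.new_order = new_order

    class Apply(object):
        def __init__(self, op, inputs):
            self.op = op
            self.inputs = inputs

    class Variable(object):
        def __init__(self, broadcastable, owner=None):
            self.broadcastable = broadcastable  # mirrors x_in.type.broadcastable
            self.owner = owner                  # Apply node that produced it, or None

    def extract_vector(x_in):
        """Mirror the new condition: unwrap only a DimShuffle of a real vector."""
        if list(x_in.broadcastable) != [True, False]:
            return None  # not a broadcasted row; goes to the non_vectors branch
        if (x_in.owner and isinstance(x_in.owner.op, DimShuffle)
                and list(x_in.owner.inputs[0].broadcastable) == [False]):
            # cut out the DimShuffle that was broadcasting a vector
            return x_in.owner.inputs[0]
        # otherwise a corrective DimShuffle would be inserted around x_in
        return x_in

    vec = Variable(broadcastable=(False,))
    row = Variable(broadcastable=(True, False),
                   owner=Apply(DimShuffle((False,), ('x', 0)), [vec]))
    assert extract_vector(row) is vec  # the wrapper DimShuffle is peeled off

The added broadcastable check on x_in.owner.inputs[0] replaces the old assert: an input that merely looks row-shaped but was not built by broadcasting a 1-d vector is no longer unwrapped.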