提交 74551881 authored 作者: Reyhane Askari's avatar Reyhane Askari

removed useless code for removing previous redundant sum and dimshuffle

上级 167fe839
...@@ -1728,21 +1728,6 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1728,21 +1728,6 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
else: else:
return return
# If the arg to softmax is a broadcasted vector, d_sm has the form:
# DimShuffle{x,0}(Sum{0}(...))
# we consider what's inside of the sum instead
vector_softmax = False
if d_sm.owner and isinstance(d_sm.owner.op, tensor.DimShuffle):
ds_op = d_sm.owner.op
if ds_op.input_broadcastable == (False,) and ds_op.new_order == ('x', 0):
maybe_sum = d_sm.owner.inputs[0]
if maybe_sum.owner and isinstance(maybe_sum.owner.op, tensor.Sum):
if sm.broadcastable == (True, False)\
and maybe_sum.owner.op.axis == (0,)\
and len(maybe_sum.owner.inputs) == 1:
vector_softmax = True
d_sm = maybe_sum.owner.inputs[0]
# Two cases are supported: # Two cases are supported:
# 1. AdvancedIncSubtensor( # 1. AdvancedIncSubtensor(
# zeros_like(softmax(x)), # zeros_like(softmax(x)),
...@@ -1885,8 +1870,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1885,8 +1870,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# Check z is zeros_like(log(sm)) # Check z is zeros_like(log(sm))
if not _is_const(z, 0): if not _is_const(z, 0):
return return
if z.broadcastable != (False, False): if z.broadcastable not in [(False, False), (True, False)]:
if not (vector_softmax and z.broadcastable == (True, False)):
return return
# here we know that we are incrementing a matrix of zeros # here we know that we are incrementing a matrix of zeros
# (or a broadcasted vector). # (or a broadcasted vector).
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论