提交 ebc4e2a1 authored 作者: fsavard's avatar fsavard

Removed old code in comments for tensor.Prod gradient.

上级 2730a0a7
...@@ -1260,77 +1260,3 @@ class ProdWithoutZeros(CAReduce): ...@@ -1260,77 +1260,3 @@ class ProdWithoutZeros(CAReduce):
else: else:
return "ProdWithoutZeros{%s}" % ", ".join(map(str, self.axis)) return "ProdWithoutZeros{%s}" % ", ".join(map(str, self.axis))
"""
class ProdGrad(Op):
'''
grad is more complex than it might seem, as the cases where:
- all inputs in a group are non-zero
- only one input in the group is zero
- more than one input the group is zero
must be handled separately
What follows is a very restricted and inefficient way to compute this
gradient (inefficient in that it's coded in Python and it doesn't share
computations which might be shared with the computation of the products
themselves; restricted in that it only handles a very simple case).
'''
def __init__(self,axis=1):
self.axis = axis
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
def __hash__(self):
return hash(type(self)) ^ hash(self.axis)
def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.axis)
def make_node(self, prod_in, gz):
# make_node should only be called by the grad function of DownsampleFactorMax,
# so these asserts should not fail.
assert isinstance(prod_in, Variable) and prod_in.ndim==2
return Apply(self, [prod_in, gz], [prod_in.type()])
def perform(self, node, (prod_in, gz), (gx_stg,)):
gx = numpy.zeros_like(prod_in)
# only supported case, for the moment,
# therefore is when we have a 2d tensor (matrix)
# where we take the prod over rows
if self.axis is None or self.axis != 1 and self.axis != (1,) and self.axis != [1,]:
raise NotImplementedError("tensor.Prod default gradient only handles the case where the input is a matrix and product is taken over rows. If you're sure the input does not contain any zeros, a simpler and more efficient gradient may be used; set the Prod constructor option 'no_zeros_in_input'.")
for row_idx in range(prod_in.shape[0]):
row = prod_in[row_idx,:]
zero_count = 0
first_zero_position = 0
prod = gz[row_idx]
prod_without_zeros = prod
for el_idx, el in enumerate(row):
if el == 0.:
zero_count += 1
first_zero_position = el_idx
prod = 0.
else:
prod *= el
prod_without_zeros *= el
if zero_count == 0:
gx[row_idx, :] = prod / prod_in[row_idx,:]
elif zero_count == 1:
# all elements (but the one = 0) get a gradient value of 0
gx[row_idx, first_zero_position] = prod_without_zeros
else:
# case with more than 1 zeros, where everything ends up 0.
# just leave it the way it was initialized (with zeros)
pass
gx_stg[0] = gx
"""
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论