if self.axis is None or self.axis != 1 and self.axis != (1,) and self.axis != [1,]:
raise NotImplementedError("tensor.Prod default gradient only handles the case where the input is a matrix and product is taken over rows. If you're sure the input does not contain any zeros, a simpler and more efficient gradient may be used; set the Prod constructor option 'no_zeros_in_input'.")
for row_idx in range(prod_in.shape[0]):
row = prod_in[row_idx,:]
zero_count = 0
first_zero_position = 0
prod = gz[row_idx]
prod_without_zeros = prod
for el_idx, el in enumerate(row):
if el == 0.:
zero_count += 1
first_zero_position = el_idx
prod = 0.
else:
prod *= el
prod_without_zeros *= el
if zero_count == 0:
gx[row_idx, :] = prod / prod_in[row_idx,:]
elif zero_count == 1:
# all elements (but the one = 0) get a gradient value of 0