提交 61dd9af7 authored 作者: fsavard's avatar fsavard

Modified theano.prod gradient based on James' suggestion

上级 6dbad245
...@@ -1173,20 +1173,25 @@ class Prod(CAReduce): ...@@ -1173,20 +1173,25 @@ class Prod(CAReduce):
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
p_gz = theano.tensor.mul(prod_out, gz)
p_gz = DimShuffle(p_gz.type.broadcastable, new_dims)(p_gz)
return [Elemwise(scalar.true_div)(p_gz, x)]
# fill a matrix with the same shape as x by broadcasting # fill a matrix with the same shape as x by broadcasting
# values taken from gz, which has the same shape as the output # values taken from gz, which has the same shape as the output
# of prod(). # of prod().
gz_filled_x = Elemwise(scalar.second)(x, #gz_filled_x = Elemwise(scalar.second)(x,
DimShuffle(gz.type.broadcastable, new_dims)(gz)) # DimShuffle(gz.type.broadcastable, new_dims)(gz))
# do the same with the output of prod, by broadcasting along # do the same with the output of prod, by broadcasting along
# axises where the product was taken # axises where the product was taken
prod_out_filled_x = Elemwise(scalar.second)(x, #prod_out_filled_x = Elemwise(scalar.second)(x,
DimShuffle(prod_out.type.broadcastable, # DimShuffle(prod_out.type.broadcastable,
new_dims)(prod_out)) # new_dims)(prod_out))
return [theano.tensor.mul(gz_filled_x, #return [theano.tensor.mul(gz_filled_x,
theano.tensor.true_div(prod_out_filled_x, x))] # theano.tensor.true_div(prod_out_filled_x, x))]
#else: #else:
# raise NotImplementedError('Will be implemented shortly') # raise NotImplementedError('Will be implemented shortly')
......
...@@ -301,5 +301,6 @@ class test_Prod(unittest.TestCase): ...@@ -301,5 +301,6 @@ class test_Prod(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#suite = unittest.TestSuite([test_Prod('test_prod_grad')]) #suite = unittest.TestSuite([test_Prod('test_prod_grad')])
#suite.addTest(test_Prod('test_verify_grad'))
#unittest.TextTestRunner().run(suite) #unittest.TextTestRunner().run(suite)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论