提交 f8012657 authored 作者: David Warde-Farley's avatar David Warde-Farley

Comment TileGrad and raise NotImplementedError

TileGrad implements only one very narrow special case, and it implements it wrongly. Raise NotImplementedError instead of using this botched implementation.
上级 96363cb9
......@@ -4685,25 +4685,25 @@ def flatten(x, outdim=1):
return Flatten(outdim)(x)
class TileGrad(Op):
"""
Calculates the gradient of the Tile Op.
"""
#this is so weird, I can't think of how to make this a general thing.
def make_node(self, x, reps, g_out):
return gof.Apply(self, [x, reps, g_out], [x.type()])
def perform(self, node, inp, out):
x, reps, g_out = inp
gx, = out
xsh = x.shape
if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
gx[0] = numpy.sum(g_out, axis=0)
else:
raise NotImplementedError('x.shape, reps combination not '
'supported', (x.shape, reps))
tilegrad = TileGrad()
# class TileGrad(Op):
# """
# Calculates the gradient of the Tile Op.
# """
# #this is so weird, I can't think of how to make this a general thing.
# def make_node(self, x, reps, g_out):
# return gof.Apply(self, [x, reps, g_out], [x.type()])
#
# def perform(self, node, inp, out):
# x, reps, g_out = inp
# gx, = out
# xsh = x.shape
# if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
# gx[0] = numpy.sum(g_out, axis=0)
# else:
# raise NotImplementedError('x.shape, reps combination not '
# 'supported', (x.shape, reps))
#
# tilegrad = TileGrad()
class Tile(Op):
......@@ -4742,7 +4742,8 @@ class Tile(Op):
def grad(self, inp, grads):
x, reps = inp
g_out, = grads
return [tilegrad(x, reps, g_out), None]
# return [tilegrad(x, reps, g_out), None]
raise NotImplementedError()
def tile(x, reps, ndim=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论