提交 f8012657 authored 作者: David Warde-Farley's avatar David Warde-Farley

Comment TileGrad and raise NotImplementedError

TileGrad implements only one very narrow special case, and it implements it wrongly. Raise NotImplementedError instead of using this botched implementation.
上级 96363cb9
...@@ -4685,25 +4685,25 @@ def flatten(x, outdim=1): ...@@ -4685,25 +4685,25 @@ def flatten(x, outdim=1):
return Flatten(outdim)(x) return Flatten(outdim)(x)
class TileGrad(Op): # class TileGrad(Op):
""" # """
Calculates the gradient of the Tile Op. # Calculates the gradient of the Tile Op.
""" # """
#this is so weird, I can't think of how to make this a general thing. # #this is so weird, I can't think of how to make this a general thing.
def make_node(self, x, reps, g_out): # def make_node(self, x, reps, g_out):
return gof.Apply(self, [x, reps, g_out], [x.type()]) # return gof.Apply(self, [x, reps, g_out], [x.type()])
#
def perform(self, node, inp, out): # def perform(self, node, inp, out):
x, reps, g_out = inp # x, reps, g_out = inp
gx, = out # gx, = out
xsh = x.shape # xsh = x.shape
if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1: # if len(reps) == 2 and reps[1] == 1 and len(x.shape) == 1:
gx[0] = numpy.sum(g_out, axis=0) # gx[0] = numpy.sum(g_out, axis=0)
else: # else:
raise NotImplementedError('x.shape, reps combination not ' # raise NotImplementedError('x.shape, reps combination not '
'supported', (x.shape, reps)) # 'supported', (x.shape, reps))
#
tilegrad = TileGrad() # tilegrad = TileGrad()
class Tile(Op): class Tile(Op):
...@@ -4742,7 +4742,8 @@ class Tile(Op): ...@@ -4742,7 +4742,8 @@ class Tile(Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, reps = inp x, reps = inp
g_out, = grads g_out, = grads
return [tilegrad(x, reps, g_out), None] # return [tilegrad(x, reps, g_out), None]
raise NotImplementedError()
def tile(x, reps, ndim=None): def tile(x, reps, ndim=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论