fixed grad signatures in tensor.py

上级 0976892d
......@@ -90,6 +90,9 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.shape == (5,2))
self.failUnless(type(z) is mtype)
def test_missing(self):
raise NotImplementedError('tests commented out')
# def test_basic1(self):
# """dot: sparse left"""
# a = numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, 0]],
......
......@@ -173,6 +173,7 @@ def broadcast(scalar_opclass, name, inplace_versions = True):
return C, c
class Argmax(Op):
"""Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis
nout=2 # max val, max idx
E_axis = 'invalid axis'
......@@ -238,8 +239,8 @@ class TransposeInplace(_Op, Viewer):
return [rval]
def impl(self, x):
return x.T #numpy's transpose
def grad(self, x, gz):
return transpose(gz)
def grad(self, (x,), (gz),):
return transpose(gz),
def c_code(self, (x, ), (z, ), sub):
return """
......@@ -306,7 +307,7 @@ class Subtensor(Op, Viewer):
self.outputs[0].data = x.__getitem__(c[0])
else:
self.outputs[0].data = x.__getitem__(c)
def grad(x, gz):
def grad(self, (x,), (gz,)):
# - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz
# - option: return a sparse matrix
......@@ -350,10 +351,7 @@ class Dot(_Op):
return [self.broadcastable_rule(bx,by)]
def impl(self, x, y):
return numpy.dot(x, y)
def grad(self, (x, y), gz):
"""
@todo Shouldn't it be (gz,) ? -jpt
"""
def grad(self, (x, y), (gz,)):
return dot(gz, y.T), dot(x.T, gz)
if 0:
def c_support_code(self):
......@@ -414,7 +412,7 @@ class Gemm(_Op):
z *= b
z += a * numpy.dot(x,y)
return z
def grad(self, (z, a, x, y, b), gz):
def grad(self, (z, a, x, y, b), (gz,)):
raise NotImplementedError()
def c_support_code(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论