fixed grad signatures in tensor.py

上级 0976892d
...@@ -90,6 +90,9 @@ class _testCase_dot(unittest.TestCase): ...@@ -90,6 +90,9 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.shape == (5,2)) self.failUnless(z.shape == (5,2))
self.failUnless(type(z) is mtype) self.failUnless(type(z) is mtype)
def test_missing(self):
raise NotImplementedError('tests commented out')
# def test_basic1(self): # def test_basic1(self):
# """dot: sparse left""" # """dot: sparse left"""
# a = numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, 0]], # a = numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, 0]],
......
...@@ -173,6 +173,7 @@ def broadcast(scalar_opclass, name, inplace_versions = True): ...@@ -173,6 +173,7 @@ def broadcast(scalar_opclass, name, inplace_versions = True):
return C, c return C, c
class Argmax(Op): class Argmax(Op):
"""Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis nin=2 # tensor, axis
nout=2 # max val, max idx nout=2 # max val, max idx
E_axis = 'invalid axis' E_axis = 'invalid axis'
...@@ -238,8 +239,8 @@ class TransposeInplace(_Op, Viewer): ...@@ -238,8 +239,8 @@ class TransposeInplace(_Op, Viewer):
return [rval] return [rval]
def impl(self, x): def impl(self, x):
return x.T #numpy's transpose return x.T #numpy's transpose
def grad(self, x, gz): def grad(self, (x,), (gz),):
return transpose(gz) return transpose(gz),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return """ return """
...@@ -306,7 +307,7 @@ class Subtensor(Op, Viewer): ...@@ -306,7 +307,7 @@ class Subtensor(Op, Viewer):
self.outputs[0].data = x.__getitem__(c[0]) self.outputs[0].data = x.__getitem__(c[0])
else: else:
self.outputs[0].data = x.__getitem__(c) self.outputs[0].data = x.__getitem__(c)
def grad(x, gz): def grad(self, (x,), (gz,)):
# - option: allocate a potentially large matrix of zeros, and fill in # - option: allocate a potentially large matrix of zeros, and fill in
# the appropriate elements from gz # the appropriate elements from gz
# - option: return a sparse matrix # - option: return a sparse matrix
...@@ -350,10 +351,7 @@ class Dot(_Op): ...@@ -350,10 +351,7 @@ class Dot(_Op):
return [self.broadcastable_rule(bx,by)] return [self.broadcastable_rule(bx,by)]
def impl(self, x, y): def impl(self, x, y):
return numpy.dot(x, y) return numpy.dot(x, y)
def grad(self, (x, y), gz): def grad(self, (x, y), (gz,)):
"""
@todo Shouldn't it be (gz,) ? -jpt
"""
return dot(gz, y.T), dot(x.T, gz) return dot(gz, y.T), dot(x.T, gz)
if 0: if 0:
def c_support_code(self): def c_support_code(self):
...@@ -414,7 +412,7 @@ class Gemm(_Op): ...@@ -414,7 +412,7 @@ class Gemm(_Op):
z *= b z *= b
z += a * numpy.dot(x,y) z += a * numpy.dot(x,y)
return z return z
def grad(self, (z, a, x, y, b), gz): def grad(self, (z, a, x, y, b), (gz,)):
raise NotImplementedError() raise NotImplementedError()
def c_support_code(self): def c_support_code(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论