提交 b8db858d authored 作者: abalkin's avatar abalkin

Implemented x.dot(y) for tensor variables.

上级 51290164
......@@ -1732,6 +1732,8 @@ class _tensor_py_operators:
def __rdot__(right, left):
return dot(left, right)
dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
......
......@@ -7001,6 +7001,18 @@ class TestInferShape(utt.InferShapeTester):
[tile(adtens4, aivec_val, ndim)],
[adtens4_val], Tile)
class TestTensorInstanceMethods(unittest.TestCase):
def setUp(self):
self.vars = matrices('X', 'Y')
self.vals = [rand(2,2),rand(2,2)]
def test_dot(self):
X, Y = self.vars
x, y = self.vals
self.assertTrue(numpy.all(x.dot(y) == X.dot(Y).eval({X: x, Y: y})))
Z = X.dot(Y)
z = x.dot(y)
self.assertTrue(numpy.all(x.dot(z) == X.dot(Z).eval({X: x, Z: z})))
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论