提交 b8db858d authored 作者: abalkin's avatar abalkin

Implemented x.dot(y) for tensor variables.

上级 51290164
...@@ -1732,6 +1732,8 @@ class _tensor_py_operators: ...@@ -1732,6 +1732,8 @@ class _tensor_py_operators:
def __rdot__(right, left): def __rdot__(right, left):
return dot(left, right) return dot(left, right)
dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False): def sum(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.sum`""" """See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims) return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
......
...@@ -7001,7 +7001,19 @@ class TestInferShape(utt.InferShapeTester): ...@@ -7001,7 +7001,19 @@ class TestInferShape(utt.InferShapeTester):
[tile(adtens4, aivec_val, ndim)], [tile(adtens4, aivec_val, ndim)],
[adtens4_val], Tile) [adtens4_val], Tile)
class TestTensorInstanceMethods(unittest.TestCase):
def setUp(self):
self.vars = matrices('X', 'Y')
self.vals = [rand(2,2),rand(2,2)]
def test_dot(self):
X, Y = self.vars
x, y = self.vals
self.assertTrue(numpy.all(x.dot(y) == X.dot(Y).eval({X: x, Y: y})))
Z = X.dot(Y)
z = x.dot(y)
self.assertTrue(numpy.all(x.dot(z) == X.dot(Z).eval({X: x, Z: z})))
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论