提交 6b6cee3b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

import gof import gof
import base_tensor
import tensor
import sparse
import compile
import gradient
import opt
from base_tensor import * from base_tensor import *
from tensor import * from tensor import *
from compile import * from compile import *
#from sparse import *
from opt import * from opt import *
from gradient import * from gradient import *
...@@ -647,6 +647,7 @@ class t_gemm(unittest.TestCase): ...@@ -647,6 +647,7 @@ class t_gemm(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(44) numpy.random.seed(44)
_approx_eq.debug = 0 _approx_eq.debug = 0
Gemm.debug = False
@staticmethod @staticmethod
def _gemm(z,a,x,y,b): def _gemm(z,a,x,y,b):
...@@ -681,6 +682,7 @@ class t_gemm(unittest.TestCase): ...@@ -681,6 +682,7 @@ class t_gemm(unittest.TestCase):
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker) cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker)
def test0a(self): def test0a(self):
Gemm.debug = True
try: try:
g = gemm([1.], 1., [1.], [1.], 1.) g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e: except ValueError, e:
...@@ -723,5 +725,52 @@ class t_gemm(unittest.TestCase): ...@@ -723,5 +725,52 @@ class t_gemm(unittest.TestCase):
def test12(self): self.cmp(self.rand(3,4), -1.0, def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, Z.T, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z.T, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = astensor(self.rand(2,2))
A = astensor(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -56,7 +56,9 @@ class ResultBase(object): ...@@ -56,7 +56,9 @@ class ResultBase(object):
__slots__ = ['_role', '_data', 'state', '_name', '_hash_id'] __slots__ = ['_role', '_data', 'state', '_name', '_hash_id']
def __init__(self, role=None, name=None): def __init__(self, role=None, name=None):
self._role = role self._role = None
if role is not None:
self.role = role
self._data = [None] self._data = [None]
self.state = Empty self.state = Empty
self.name = name self.name = name
...@@ -98,6 +100,9 @@ class ResultBase(object): ...@@ -98,6 +100,9 @@ class ResultBase(object):
if _index != index: if _index != index:
raise ValueError("Result %s was already mapped to a different index." % self) raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index return # because _owner is owner and _index == index
#TODO: this doesn't work because many bits of code set the role before
# owner.outputs. Op.__init__ should do this I think. -JSB
#assert owner.outputs[index] is self
self._role = role self._role = role
role = property(__get_role, __set_role) role = property(__get_role, __set_role)
...@@ -331,6 +336,9 @@ class PythonResult(ResultBase): ...@@ -331,6 +336,9 @@ class PythonResult(ResultBase):
rval.data = copy.copy(self.data) rval.data = copy.copy(self.data)
return rval return rval
def python_result(data, **kwargs):
rval = PythonResult(**kwargs)
rval.data = data
return rval
...@@ -366,6 +366,13 @@ class Gemm(_Op): ...@@ -366,6 +366,13 @@ class Gemm(_Op):
nout=1 nout=1
E_rank = 'gemm only works for rank 2' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument' E_scalar = 'gemm requires scalar argument'
E_z_uniq = 'argument z not unique in argument list'
debug = False
def __init__(self, *args, **kwargs):
_Op.__init__(self, *args, **kwargs)
z, a, x, y, b = self.inputs
if z in self.inputs[1:]:
raise ValueError(Gemm.E_z_uniq, self.inputs)
def destroy_map(self): def destroy_map(self):
return {self.out:[self.inputs[0]]} return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb): def propagate_broadcastable(self, bz, ba, bx, by, bb):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论