提交 8c211535 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

quick port of Elemwise, add/sub/mul and scalar_switch

上级 93b4e940
...@@ -14,9 +14,12 @@ import sys ...@@ -14,9 +14,12 @@ import sys
def inputs(): def inputs():
x = modes.build(tensor([[1.0, 2.0], [3.0, 4.0]], 'x')) l1 = [[1.0, 2.0], [3.0, 4.0]]
y = None l2 = [[3.0, 4.0], [1.0, 2.0]]
z = None l3 = numpy.ones((2, 3))
x = modes.build(tensor(l1, 'x'))
y = modes.build(tensor(l2, 'y'))
z = modes.build(tensor(l3, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
...@@ -27,24 +30,34 @@ class _test_TensorOps(unittest.TestCase): ...@@ -27,24 +30,34 @@ class _test_TensorOps(unittest.TestCase):
def test_0(self): def test_0(self):
x, y, z = inputs() x, y, z = inputs()
e = transpose(x) # e = mul(add(x, y), 2)
g = env([x], [e]) e = (x + y) * 2
fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk() fn, i, o = gof.PerformLinker(env([x, y], [e])).make_thunk(True)
i.data = [[1.0, 2.0], [3.0, 4.0]]
# print sys.getrefcount(i.data)
fn() fn()
# print sys.getrefcount(i.data) print e
# print sys.getrefcount(o.data)
print o.data
# assert res == numpy.asarray(arr)
# def test_1(self): # def test_0(self):
# x, y, z = inputs() # x, y, z = inputs()
# e = mul(add(x, y), div(x, y)) # e = transpose(x)
# g = env([x, y], [e]) # g = env([x], [e])
# fn = gof.cc.CLinker(g).make_function() # fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk()
# assert fn(1.0, 2.0) == 1.5 # i.data = [[1.0, 2.0], [3.0, 4.0]]
# assert e.data == 1.5 # # print sys.getrefcount(i.data)
# fn()
# # print sys.getrefcount(i.data)
# # print sys.getrefcount(o.data)
# print o.data
# # assert res == numpy.asarray(arr)
# # def test_1(self):
# # x, y, z = inputs()
# # e = mul(add(x, y), div(x, y))
# # g = env([x, y], [e])
# # fn = gof.cc.CLinker(g).make_function()
# # assert fn(1.0, 2.0) == 1.5
# # assert e.data == 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -27,6 +27,7 @@ class Tensor(ResultBase): ...@@ -27,6 +27,7 @@ class Tensor(ResultBase):
if dtype is None or broadcastable is None: if dtype is None or broadcastable is None:
if data is None: if data is None:
raise TypeError("Provide non-None data to complete the dtype and broadcastable flags.") raise TypeError("Provide non-None data to complete the dtype and broadcastable flags.")
data = numpy.asarray(data)
dtype = data.dtype dtype = data.dtype
if constant: if constant:
broadcastable = [1*(x == 1) for x in data.shape] broadcastable = [1*(x == 1) for x in data.shape]
...@@ -35,7 +36,7 @@ class Tensor(ResultBase): ...@@ -35,7 +36,7 @@ class Tensor(ResultBase):
self.broadcastable = broadcastable self.broadcastable = broadcastable
self.dtype = str(dtype) self.dtype = str(dtype)
self.constant = constant self.constant = constant
ResultBase.__init__(self, role = None, data = None, name = name) ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self): def __get_constant(self):
return self._constant return self._constant
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论