提交 692a906b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added a whole bunch of supported dtypes in base_tensor

上级 fdb0a71b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
import numpy import numpy
from tensor import tensor, Tensor from tensor import tinit, Tensor
import gof import gof
from gof import modes, Env from gof import modes, Env
...@@ -10,6 +10,10 @@ from elemwise import * ...@@ -10,6 +10,10 @@ from elemwise import *
class ElemwiseAdd(Elemwise): class ElemwiseAdd(Elemwise):
def __init__(self, x, y):
self.inputs = (x, y)
self.outputs = [Tensor(dtype = x.dtype, broadcastable = x.broadcastable)]
def var_desc(self): def var_desc(self):
return [('x', 1), ('y', 1)], [('z', 1)] return [('x', 1), ('y', 1)], [('z', 1)]
...@@ -25,9 +29,9 @@ def inputs(): ...@@ -25,9 +29,9 @@ def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]] l1 = [[1.0, 2.0], [3.0, 4.0]]
l2 = [[3.0, 4.0], [1.0, 2.0]] l2 = [[3.0, 4.0], [1.0, 2.0]]
l3 = numpy.ones((2, 3)) l3 = numpy.ones((2, 3))
x = modes.build(tensor(l1, name = 'x')) x = modes.build(tinit(l1, name = 'x'))
y = modes.build(tensor(l2, name = 'y')) y = modes.build(tinit(l2, name = 'y'))
z = modes.build(tensor(l3, name = 'z')) z = modes.build(tinit(l3, name = 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
...@@ -47,6 +51,17 @@ class _test_Elemwise(unittest.TestCase): ...@@ -47,6 +51,17 @@ class _test_Elemwise(unittest.TestCase):
fn() fn()
assert (e.data == numpy.array([[4, 6, 4, 6]])).all() assert (e.data == numpy.array([[4, 6, 4, 6]])).all()
# def test_1(self):
# x, y, z = inputs()
# e = ElemwiseAdd(x, y).out
# fn, i, o = gof.CLinker(env([x, y], [e])).make_thunk(True)
# fn()
# assert (e.data == numpy.array([[4, 6], [4, 6]])).all()
# x.data.resize((1, 4))
# y.data.resize((1, 4))
# fn()
# assert (e.data == numpy.array([[4, 6, 4, 6]])).all()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -80,7 +80,12 @@ class BaseTensor(ResultBase): ...@@ -80,7 +80,12 @@ class BaseTensor(ResultBase):
#TODO: add more type correspondances for e.g. int32, int64, float32, #TODO: add more type correspondances for e.g. int32, int64, float32,
#complex64, etc. #complex64, etc.
try: try:
return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype] return {'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'int8': (int, 'npy_int8', 'NPY_INT8'),
'int16': (int, 'npy_int16', 'NPY_INT16'),
'int32': (int, 'npy_int32', 'NPY_INT32'),
'int64': (int, 'npy_int64', 'NPY_INT64')}[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for BaseTensor: %s" % self.dtype) raise TypeError("Unsupported dtype for BaseTensor: %s" % self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论