提交 42ba5a5d authored 作者: james@X40's avatar james@X40

added shortcuts for complex tensor types

上级 24f4ad19
......@@ -407,6 +407,8 @@ def _multi(*fns):
else:
return [partial(f2, f) for f in fns]
cscalar = Tensor('complex64', ())
zscalar = Tensor('complex128', ())
fscalar = Tensor('float32', ())
dscalar = Tensor('float64', ())
bscalar = Tensor('int8', ())
......@@ -420,9 +422,13 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
int_types = bscalar, wscalar, iscalar, lscalar
float_types = fscalar, dscalar
complex_types = cscalar, zscalar
int_scalar_types = int_types
float_scalar_types = float_types
complex_scalar_types = complex_types
cvector = Tensor('complex64', (False, ))
zvector = Tensor('complex128', (False, ))
fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, ))
bvector = Tensor('int8', (False,))
......@@ -436,7 +442,10 @@ vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvecto
int_vector_types = bvector, wvector, ivector, lvector
float_vector_types = fvector, dvector
complex_vector_types = cvector, zvector
cmatrix = Tensor('complex64', (False, False))
zmatrix = Tensor('complex128', (False, False))
fmatrix = Tensor('float32', (False, False))
dmatrix = Tensor('float64', (False, False))
bmatrix = Tensor('int8', (False, False))
......@@ -450,7 +459,10 @@ matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, d
int_matrix_types = bmatrix, wmatrix, imatrix, lmatrix
float_matrix_types = fmatrix, dmatrix
complex_matrix_types = cmatrix, zmatrix
crow = Tensor('complex64', (True, False))
zrow = Tensor('complex128', (True, False))
frow = Tensor('float32', (True, False))
drow = Tensor('float64', (True, False))
brow = Tensor('int8', (True, False))
......@@ -462,6 +474,8 @@ def row(name = None, dtype = 'float64'):
return type(name)
rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
ccol = Tensor('complex64', (False, True))
zcol = Tensor('complex128', (False, True))
fcol = Tensor('float32', (False, True))
dcol = Tensor('float64', (False, True))
bcol = Tensor('int8', (False, True))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论