提交 07ceb6df authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added utility functions to construct Tensors (i|f|)(scalar|vector|matrix|row|col)(s|)

上级 1ddd6c38
...@@ -14,6 +14,8 @@ import blas # for gemm, dot ...@@ -14,6 +14,8 @@ import blas # for gemm, dot
import elemwise as s2t import elemwise as s2t
import scalar as scal import scalar as scal
from functools import partial
class Tensor(BaseTensor): class Tensor(BaseTensor):
""" """
...@@ -100,6 +102,49 @@ def astensor(data, broadcastable=None, name=None): ...@@ -100,6 +102,49 @@ def astensor(data, broadcastable=None, name=None):
s2t.astensor = astensor s2t.astensor = astensor
# Easy constructors
def _multi(*fns):
def f2(f, names):
if len(names) == 1:
return f(names)
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
else:
return [partial(f2, f) for f in fns]
def _int_float(f):
return partial(f, dtype = 'int64'), partial(f, dtype = 'float64')
def scalar(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = ())
iscalar, fscalar = _int_float(scalar)
scalars, iscalars, fscalars = _multi(scalar, iscalar, fscalar)
def vector(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (False))
ivector, fvector = _int_float(vector)
vectors, ivectors, fvectors = _multi(vector, ivector, fvector)
def matrix(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (False, False))
imatrix, fmatrix = _int_float(matrix)
matrices, imatrices, fmatrices = _multi(matrix, imatrix, fmatrix)
def row(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (True, False))
irow, frow = _int_float(row)
rows, irows, frows = _multi(row, irow, frow)
def col(name, dtype = 'float64'):
return Tensor(name = name, dtype = dtype, broadcastable = (False, True))
icol, fcol = _int_float(col)
cols, icols, fcols = _multi(col, icol, fcol)
############################ ############################
# Supporting Ops # Supporting Ops
############################ ############################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论