提交 625e298a authored 作者: Olivier Breuleux's avatar Olivier Breuleux

finished porting the ops

上级 8c211535
...@@ -34,7 +34,14 @@ class _test_TensorOps(unittest.TestCase): ...@@ -34,7 +34,14 @@ class _test_TensorOps(unittest.TestCase):
e = (x + y) * 2 e = (x + y) * 2
fn, i, o = gof.PerformLinker(env([x, y], [e])).make_thunk(True) fn, i, o = gof.PerformLinker(env([x, y], [e])).make_thunk(True)
fn() fn()
print e assert (e.data == numpy.array([[8, 12], [8, 12]])).all()
def test_1(self):
x, y, z = inputs()
e = dot(x, z).T
fn, i, o = gof.PerformLinker(env([x, z], [e])).make_thunk(True)
fn()
assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all()
# def test_0(self): # def test_0(self):
# x, y, z = inputs() # x, y, z = inputs()
......
"""
Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5)
"""
import sys
if sys.version_info[:2] < (2,5):
def all(iterable):
for element in iterable:
if not element:
return False
return True
else:
# Only bother with this else clause and the __all__ line if you are putting
# this in a separate file.
import __builtin__
all = __builtin__.all
__all__ = ['all']
from tensor import * from tensor import *
from gof import Op, utils from gof import Op, utils, Destroyer, Viewer
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
...@@ -81,7 +81,7 @@ class Transpose(UnaryTensorOp): ...@@ -81,7 +81,7 @@ class Transpose(UnaryTensorOp):
def propagate_broadcastable(self, x): def propagate_broadcastable(self, x):
x2 = copy(x) x2 = copy(x)
x2.reverse() x2.reverse()
return x2 return [x2]
def impl(self, x): def impl(self, x):
return x.T return x.T
...@@ -171,7 +171,10 @@ class Elemwise(TensorOp): ...@@ -171,7 +171,10 @@ class Elemwise(TensorOp):
@classmethod @classmethod
def inplace_version(cls): def inplace_version(cls):
return cls # placeholder class Ret(cls, Destroyer):
def destroy_list(self):
return self.inputs[0]
return Ret
def c_init(self, inputs, outputs): def c_init(self, inputs, outputs):
pass pass
...@@ -200,6 +203,46 @@ class TensorScalarOp(Elemwise): ...@@ -200,6 +203,46 @@ class TensorScalarOp(Elemwise):
return "z_i = %s;" % self.c_expr return "z_i = %s;" % self.c_expr
###########################
#### Binary Operations ####
###########################
#########
## Dot ##
#########
class Dot(TensorOp):
@staticmethod
def _output_shape(xshape, yshape):
# This describes the logic to calculate numpy.dot(x, y).shape
# given x.shape and y.shape
if len(xshape) == 0: # x is a scalar
shape = yshape
else:
if len(yshape) >= 2: #y is a matrix or tensor
assert xshape[-1] == yshape[-2]
shape = tuple(xshape[:-1]+ yshape[:-2]+yshape[-1:])
elif len(yshape)==1: #y is vector
assert xshape[-1] == yshape[-1]
shape = tuple(xshape[:-1])
else: #y is a scalar
shape = xshape
return shape
def impl(self, x, y):
return numpy.dot(x, y)
def grad(self, (x, y), gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def propagate_broadcastable(self, x, y):
assert len(x) == 2 and len(x) == len(y)
return [(x[0], y[1])]
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0')
######### #########
## Add ## ## Add ##
...@@ -236,20 +279,6 @@ class AddScalarInplace(AddScalar.inplace_version()): ...@@ -236,20 +279,6 @@ class AddScalarInplace(AddScalar.inplace_version()):
x += a x += a
return x return x
# shortcuts #
class Twice(Elemwise):
def impl(self, x):
return 2.0 * x
def grad(self, x, gz):
return scale(gz, 2.0)
def c_foreach(self, (x_i, ), (z_i, )):
"z_i = x_i + x_i;"
class TwiceInplace(Twice.inplace_version()):
def impl(self, x):
x *= 2.0
return x
######### #########
...@@ -279,7 +308,7 @@ def sub_scalar_r(x, a): ...@@ -279,7 +308,7 @@ def sub_scalar_r(x, a):
def sub_scalar_l(x, a): def sub_scalar_l(x, a):
return add_scalar(-x, a) return add_scalar(-x, a)
def sub_scalar_r_inplace(x, a): def sub_scalar_rinplace(x, a):
return add_scalar_inplace(x, -a) return add_scalar_inplace(x, -a)
...@@ -319,314 +348,327 @@ class ScaleInplace(Scale.inplace_version()): ...@@ -319,314 +348,327 @@ class ScaleInplace(Scale.inplace_version()):
x *= a x *= a
return x return x
# shortcuts #
class Sqr(Elemwise):
def impl(self, x):
return x * x
def grad(self, x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = x_i * x_i;"
class SqrInplace(Sqr.inplace_version()):
def impl(x): #########
x *= x ## Div ##
#########
# Elemwise #
class DivElemwise(Elemwise):
def impl(self, x, y):
assert_same_shapes(x, y)
return x / y
def grad(self, (x, y), gz):
return div(gz, y), -div(mul(x, gz), sqr(y))
def c_foreach(self, (x_i, y_i), (z_i, )):
return "z_i = x_i / y_i;"
class DivElemwiseInplace(DivElemwise.inplace_version()):
def impl(self, x, y):
assert_same_shapes(x, y)
x /= y
return x return x
# Scalar #
def div_scalar_r(x, a):
return scale(x, inv_elemwise(a))
class Sqrt(Elemwise): def div_scalar_l(x, a):
def impl(self, x): return scale(inv_elemwise(x), a)
return numpy.sqrt(x)
def grad(self, x, gz):
return scale(div(gz, sqrt(x)), 0.5)
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = pow(x_i, 0.5);"
class SqrtInplace(Sqrt.inplace_version()): def div_scalar_rinplace(x, a):
def impl(self, x): return scale_inplace(x, inv_elemwise(a))
x **= 0.5
#########
## Pow ##
#########
# Elemwise #
class PowElemwise(Elemwise):
def impl(self, x, y):
assert_same_shapes(x, y)
return x ** y
def grad(self, (x, s), gz):
gx = gz * s * (pow_elemwise(x, s-1.0))
gs = gz * log(x) * pow_elemwise(x, s)
return gx, gs
def c_foreach(self, (x_i, s_i), (z_i, )):
return "z_i = pow(x_i, s_i)"
class PowElemwiseInplace(PowElemwise.inplace_version()):
def impl(self, x, y):
assert_same_shapes(x, y)
x **= y
return x return x
# Scalar #
class PowScalarL(TensorScalarOp):
def impl(self, x, a):
assert_tensor_scalar(x, a)
return a ** x
def grad(self, (x, s), gz):
gx = sum(gz * s * pow_scalar_l(add_scalar(s,-1.0), x))
gs = scale(mul(gz, pow_scalar_l(s, x)), log(x))
return gx, gs
c_expr = "pow(a, x_i)"
class PowScalarR(TensorScalarOp):
def impl(self, x, a):
assert_tensor_scalar(x, a)
return x ** a
def grad(self, (x, s), gz):
gx = scale(mul_elemwise(gz,pow_scalar_r(x, add_scalar(s,-1.0))), s)
gs = sum(mul_elemwise(mul_elemwise(gz, pow_scalar_r(x,s)), log(x)))
return gx, gs
c_expr = "pow(x_i, a)"
class PowScalarRInplace(PowScalarR.inplace_version()):
def impl(self, x, a):
assert_tensor_scalar(x, a)
x **= a
return x
# ## Element-wise division ##
# class div_elemwise(elemwise):
# impl = assert_same_shapes(numpy.ndarray.__div__)
# def grad(x, y, gz):
# return div(gz, y), -div(mul(x, gz), sqr(y))
# def c_foreach((x_i, y_i), (z_i, )):
# return "z_i = x_i / y_i;"
# div_elemwise_inplace = div_elemwise.inplace_version()
# div_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__idiv__))
# def div_scalar_r(x, a): ############
# return scale(x, inv_elemwise(a)) ## Others ##
############
# def div_scalar_l(x, a): class Fill(Elemwise):
# return scale(inv_elemwise(x), a) def impl(self, model, value):
return (model * 0) + value
def grad(self, (model, value), gz):
return None, sum(gz)
def c_init(self, (model, value), (z, )):
return "value_dtype value0 = ((value_dtype*)PyArray_DATA(value))[0];"
def c_foreach(self, (model_i, value), (z_i, )):
return "z_i = value0;"
# def div_scalar_r_inplace(x, a):
# return scale_inplace(x, inv_elemwise(a))
##########################
#### Unary Operations ####
##########################
# ## Scaling ## class Transpose(TensorOp, Viewer):
def view_map(self):
return {self.out: [self.inputs[0]]}
def impl(self, x):
return x.T
def grad(self, x, gz):
return transpose_copy(gz)
def propagate_broadcastable(self, x):
rval = list(x)
rval.reverse()
return [rval]
def c_impl(self, (x, ), (xt, )):
return """
const int l = x->nd;
// The user must ensure that all references to
//xt->data go through xt, or there's going to be trouble..
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd)
{
// this technique comes from PyArray_Resize()
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr)
{
PyErr_NoMemory();
return 1;
}
xt->nd = x->nd;
xt->dimensions = dimptr;
xt->strides = dimptr + x->nd;
}
//copy x's dimensions and strides
for (int i = 0; i < l; ++i)
{
xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1];
}
# class scale(tensor_scalar_op): // point directly at b's type descriptor
# impl = tensor_scalar_impl(numpy.ndarray.__mul__) Py_INCREF(x->descr);
# def grad(x, a, gz): Py_DECREF(xt->descr);
# return scale(a, gz), sum(mul_elemwise(x, gz)) xt->descr = x->descr;
# c_expr = "x_i * a"
// name x as a base of xt, increment its refcount
if ( xt->base != (PyObject*)x)
{
Py_INCREF(x);
if ((xt->base) && (xt->base != Py_None))
{
Py_DECREF(xt->base);
}
xt->base = (PyObject*)x;
}
// mark xt as not owning its data
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
{
PyDataMem_FREE(xt->data);
xt->flags &= ~NPY_OWNDATA;
}
xt->data = x->data;
# scale_inplace = scale.inplace_version() // this function is described in
# scale_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__imul__)) // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/*
TODO
What should be done with the weakreflist ?
*/
"""
# class neg(elemwise): def transpose_copy(x):
# impl = numpy.ndarray.__neg__ return array_copy(transpose(x))
# def grad(x, gz):
# return -gz
# def c_foreach((x_i, ), (z_i, )):
# return "z_i = -x_i;"
# neg_inplace = neg.inplace_version()
# neg_inplace.set_impl(lambda x: x.__imul__(-1))
class Neg(Elemwise):
def impl(self, x):
return -x
def grad(self, x, gz):
return -gz
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = -x_i;"
# class inv_elemwise(elemwise): class NegInplace(Neg.inplace_version()):
# impl = lambda x: 1 / x def impl(self, x):
# def grad(x, gz): x *= -1
# return -gz return x
# def c_foreach((x_i, ), (z_i, )):
# return "z_i = 1 / x_i;"
# inv_elemwise_inplace = inv_elemwise.inplace_version()
class InvElemwise(Elemwise):
def impl(self, x):
return 1 / x
def grad(self, x, gz):
return -gz / (x * x)
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = 1 / x_i;"
# ## Dot product ## class InvElemwiseInplace(InvElemwise.inplace_version()):
def impl(self, x):
x[:] = 1 / x
return x
# class dot(omega_op):
# @staticmethod
# def _output_shape(xshape, yshape):
# if len(xshape) == 0: # x is a scalar
# shape = yshape
# else:
# if len(yshape) >= 2: #y is a matrix or tensor
# assert xshape[-1] == yshape[-2]
# shape = tuple(xshape[:-1]+ yshape[:-2]+yshape[-1:])
# elif len(yshape)==1: #y is vector
# assert xshape[-1] == yshape[-1]
# shape = tuple(xshape[:-1])
# else: #y is a scalar
# shape = xshape
# return shape
# impl = numpy.dot class Exp(Elemwise):
# def grad(x, y, gz): def impl(self, x): return numpy.exp(x)
# return dot(gz, transpose(y)), dot(transpose(x), gz) def grad(self, x, gz): return gz * exp(x)
# def refresh(self, alloc=False): def c_foreach(self, (x_i, ), (z_i, )): return "z_i = exp(x_i);"
# x,y = self.inputs
# shape = self._output_shape(x.shape, y.shape)
# dtype = upcast(x.dtype, y.dtype)
# if self.out.data is not None \
# and self.out.shape == shape \
# and self.out.dtype == dtype:
# return #everything is ok
# if alloc or self.out.data is not None: #data should be allocated
# self.out.data = None
# self.out.shape = shape
# self.out.dtype = dtype
# self.out.alloc()
# else:
# self.out.shape = shape
# self.out.dtype = dtype
# def c_support_code(self):
# return blas.cblas_header_text()
# def c_libs(self):
# return blas.ldflags()
# def c_impl((_x, _y), (_z, )):
# return blas.gemm_code('', '1.0', '0.0')
# ## Transposition ##
# class transpose(omega_op):
# def view_map(self): return {self.out: [self.inputs[0]]}
# impl = numpy.transpose
# def grad(x, gz):
# return transpose_copy(gz)
# def refresh_shape(self):
# rval = list(self.inputs[0].shape)
# rval.reverse()
# return rval
# def refresh_dtype(self):
# return self.inputs[0].dtype
# def c_impl((x, ), (xt, )):
# return """
# const int l = x->nd;
# // The user must ensure that all references to
# //xt->data go through xt, or there's going to be trouble..
# int refcheck = 0;
# if (x == xt)
# {
# return -1;
# }
# if (refcheck)
# {
# int refcnt = PyArray_REFCOUNT(xt);
# if ((refcnt > 2) // you might think this should be 1.. but this works
# //|| (xt->base != NULL)
# || (xt->weakreflist != NULL))
# {
# PyErr_SetString(PyExc_ValueError,
# "cannot resize an array that has "\\
# "been referenced or is referencing\\n"\\
# "another array in this way. Use the "\\
# "resize function");
# return -2;
# }
# }
# if (xt->nd != x->nd)
# {
# // this technique comes from PyArray_Resize()
# npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
# if (!dimptr)
# {
# PyErr_NoMemory();
# return 1;
# }
# xt->nd = x->nd;
# xt->dimensions = dimptr;
# xt->strides = dimptr + x->nd;
# }
# //copy x's dimensions and strides
# for (int i = 0; i < l; ++i)
# {
# xt->dimensions[i] = x->dimensions[l-i-1];
# xt->strides[i] = x->strides[l-i-1];
# }
# // point directly at b's type descriptor
# Py_INCREF(x->descr);
# Py_DECREF(xt->descr);
# xt->descr = x->descr;
# // name x as a base of xt, increment its refcount
# if ( xt->base != (PyObject*)x)
# {
# Py_INCREF(x);
# if ((xt->base) && (xt->base != Py_None))
# {
# Py_DECREF(xt->base);
# }
# xt->base = (PyObject*)x;
# }
# // mark xt as not owning its data class Log(Elemwise):
# if (PyArray_CHKFLAGS(xt, NPY_OWNDATA)) def impl(self, x): return numpy.log(x)
# { def grad(self, x, gz): return gz / x
# PyDataMem_FREE(xt->data); def c_foreach(self, (x_i, ), (z_i, )): return "z_i = log(x_i);"
# xt->flags &= ~NPY_OWNDATA;
# }
# xt->data = x->data;
# // this function is described in class Log2(Elemwise):
# // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890 def impl(self, x): return numpy.log2(x)
# PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE); def grad(self, x, gz): return gz / (x * numpy.log(2))
def c_foreach(self, (x_i, ), (z_i, )): return "z_i = log2(x_i);"
# /*
# TODO
# What should be done with the weakreflist ?
# */
# """
# def transpose_copy(x): class Twice(Elemwise):
# return array_copy(transpose(x)) def impl(self, x):
return 2.0 * x
def grad(self, x, gz):
return scale(gz, 2.0)
def c_foreach(self, (x_i, ), (z_i, )):
"z_i = x_i + x_i;"
class TwiceInplace(Twice.inplace_version()):
def impl(self, x):
x *= 2.0
return x
# ## Copy ##
# class array_copy(elemwise): class Sqr(Elemwise):
# impl = numpy.array def impl(self, x):
# grad = lambda x, gz: gz return x * x
# def c_foreach((x_i, ), (z_i, )): def grad(self, x, gz):
# return "z_i = x_i;" return scale(mul_elemwise(x, gz), 2.0)
def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = x_i * x_i;"
class SqrInplace(Sqr.inplace_version()):
def impl(x):
x *= x
return x
# ## Power ##
# class sqr(elemwise): class Sqrt(Elemwise):
# def impl(x): def impl(self, x):
# return x * x return numpy.sqrt(x)
# def grad(x, gz): def grad(self, x, gz):
# return scale(mul_elemwise(x, gz), 2.0) return scale(div(gz, sqrt(x)), 0.5)
# def c_foreach((x_i, ), (z_i, )): def c_foreach(self, (x_i, ), (z_i, )):
# return "z_i = x_i * x_i;" return "z_i = pow(x_i, 0.5);"
# sqr_inplace = sqr.inplace_version() class SqrtInplace(Sqrt.inplace_version()):
# sqr_inplace.set_impl(lambda x: x.__imul__(x)) def impl(self, x):
x **= 0.5
return x
class Sum(Elemwise):
def impl(self, x):
return numpy.sum(x)
def grad(self, x, gz):
return fill(x, gz)
def propagate_broadcastable(self, *inputs):
return [()]
def c_init(self, (x, ), (sum, )):
return "sum_dtype* sump = ((sum_dtype*)PyArray_DATA(sum)); sump[0] = 0;"
def c_foreach(self, (x_i, ), (sum, )):
return "sump[0] += x_i;"
# class sqrt(elemwise): class ArrayCopy(Elemwise):
# impl = numpy.sqrt def impl(self, x):
# def grad(x, gz): return numpy.array(x)
# return scale(div(gz, sqrt(x)), 0.5) def grad(self, x, gz):
# def c_foreach((x_i, ), (z_i, )): return gz
# return "z_i = pow(x_i, 0.5);" def c_foreach(self, (x_i, ), (z_i, )):
return "z_i = x_i;"
# sqrt_inplace = sqrt.inplace_version() class OnesLike(Elemwise):
# sqrt_inplace.set_impl(lambda x: x.__ipow__(0.5)) def impl(self, x):
return numpy.ones_like(x)
def grad(self, x, gz):
return None
class ZerosLike(Elemwise):
def impl(self, x):
return numpy.zeros_like(x)
def grad(self, x, gz):
return None
# class exp(elemwise):
# def impl(x): return numpy.exp(x)
# def grad(x, gz): return gz * exp(x)
# def c_foreach((x_i, ), (z_i, )): return "z_i = exp(x_i);"
# class log(elemwise):
# def impl(x): return numpy.log(x)
# def grad(x, gz): return gz / x
# def c_foreach((x_i, ), (z_i, )): return "z_i = log(x_i);"
# class log2(elemwise):
# def impl(x): return numpy.log2(x)
# def grad(x, gz): return gz / (x * numpy.log(2))
# def c_foreach((x_i, ), (z_i, )): return "z_i = log2(x_i);"
# class pow_elemwise(elemwise):
# impl = assert_same_shapes(numpy.ndarray.__pow__)
# def grad(x, s, gz):
# raise NotImplemented # no gs
# return gz * s * (pow_elemwise(x, s-1.0))
# def c_foreach((x_i, s_i), (z_i, )):
# return "z_i = pow(x_i, s_i)"
# pow_elemwise_inplace = pow_elemwise.inplace_version()
# pow_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__ipow__))
# class pow_scalar_l(tensor_scalar_op):
# impl = tensor_scalar_impl(lambda x, y: numpy.ndarray.__pow__(y, x))
# def grad(x, s, gz):
# raise NotImplemented # no gs
# return gz * x * (pow_scalar_l(s,x-1.0))
# c_expr = "pow(a, x_i)"
# class pow_scalar_r(tensor_scalar_op):
# impl = tensor_scalar_impl(numpy.ndarray.__pow__)
# def grad(x, s, gz):
# gx = gz * s * (pow_scalar_r(x,s-1.0))
# gs = sum(gz * pow_scalar_r(x,s) * log(x))
# return gx, gs
# c_expr = "pow(x_i, a)"
# pow_scalar_r_inplace = pow_scalar_r.inplace_version()
# pow_scalar_r_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__ipow__))
# ## Others ## # ## Others ##
...@@ -658,34 +700,6 @@ class SqrtInplace(Sqrt.inplace_version()): ...@@ -658,34 +700,6 @@ class SqrtInplace(Sqrt.inplace_version()):
# """ # """
# class fill(elemwise):
# impl = lambda model, value: (model * 0) + value
# def c_init((model, value), (z, )):
# return "value_dtype value0 = ((value_dtype*)PyArray_DATA(value))[0];"
# def c_foreach((model_i, value), (z_i, )):
# return "z_i = value0;"
# fill_inplace = fill.inplace_version()
# class sum(elemwise):
# impl = numpy.sum
# def grad(x, gz):
# return fill(x, gz)
# def refresh_shape(self):
# return ()
# def c_init((x, ), (sum, )):
# return "sum_dtype* sump = ((sum_dtype*)PyArray_DATA(sum)); sump[0] = 0;"
# def c_foreach((x_i, ), (sum, )):
# return "sump[0] += x_i;"
# class ones_like(elemwise):
# impl = numpy.ones_like
# def grad(x, gz): return Undefined
# class zeros_like(elemwise):
# impl = numpy.zeros_like
# def grad(x, gz): return Undefined
# ## Array slicing ## # ## Array slicing ##
# class get_slice(omega_op): # class get_slice(omega_op):
...@@ -713,16 +727,16 @@ add = scalar_switch(add_elemwise, add_scalar, add_scalar) ...@@ -713,16 +727,16 @@ add = scalar_switch(add_elemwise, add_scalar, add_scalar)
add_inplace = scalar_switch(add_elemwise_inplace, add_scalar_inplace) add_inplace = scalar_switch(add_elemwise_inplace, add_scalar_inplace)
sub = scalar_switch(sub_elemwise, sub_scalar_r, sub_scalar_l) sub = scalar_switch(sub_elemwise, sub_scalar_r, sub_scalar_l)
sub_inplace = scalar_switch(sub_elemwise_inplace, sub_scalar_r_inplace) sub_inplace = scalar_switch(sub_elemwise_inplace, sub_scalar_rinplace)
mul = scalar_switch(mul_elemwise, scale, scale) mul = scalar_switch(mul_elemwise, scale, scale)
mul_inplace = scalar_switch(mul_elemwise_inplace, scale_inplace) mul_inplace = scalar_switch(mul_elemwise_inplace, scale_inplace)
# div = scalar_switch(div_elemwise, div_scalar_r, div_scalar_l) div = scalar_switch(div_elemwise, div_scalar_r, div_scalar_l)
# div_inplace = scalar_switch(div_elemwise_inplace, div_scalar_r_inplace) div_inplace = scalar_switch(div_elemwise_inplace, div_scalar_rinplace)
# pow = scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l) pow = scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l)
# pow_inplace = scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace) pow_inplace = scalar_switch(pow_elemwise_inplace, pow_scalar_rinplace)
Tensor.__add__ = add Tensor.__add__ = add
Tensor.__sub__ = sub Tensor.__sub__ = sub
...@@ -730,6 +744,6 @@ Tensor.__mul__ = mul ...@@ -730,6 +744,6 @@ Tensor.__mul__ = mul
Tensor.__iadd__ = add_inplace Tensor.__iadd__ = add_inplace
Tensor.__isub__ = sub_inplace Tensor.__isub__ = sub_inplace
Tensor.__imul__ = mul_inplace Tensor.__imul__ = mul_inplace
Tensor.T = property(transpose)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论