提交 8f4208f6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Allow subclasses of _tensor_py_operators to wrap the results in their own type.

上级 fcbd4a34
......@@ -18,40 +18,43 @@ class AsTensorError(TypeError):
class _tensor_py_operators:
def _wrap(self, val):
return val
# UNARY
def __abs__(self):
return theano.tensor.basic.abs_(self)
return self._wrap(theano.tensor.basic.abs_(self))
def __neg__(self):
return theano.tensor.basic.neg(self)
return self._wrap(theano.tensor.basic.neg(self))
# CASTS
#### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return
#### an int. -JB 20081112
#def __int__(self): return convert_to_int32(self)
#def __float__(self): return convert_to_float64(self)
#def __complex__(self): return convert_to_complex128(self)
#def __int__(self): return self._wrap(convert_to_int32(self))
#def __float__(self): return self._wrap(convert_to_float64(self))
#def __complex__(self): return self._wrap(convert_to_complex128(self))
# COMPARISONS
_is_nonzero = True
def __lt__(self, other):
rval = theano.tensor.basic.lt(self, other)
rval = self._wrap(theano.tensor.basic.lt(self, other))
rval._is_nonzero = False
return rval
def __le__(self, other):
rval = theano.tensor.basic.le(self, other)
rval = self._wrap(theano.tensor.basic.le(self, other))
rval._is_nonzero = False
return rval
def __gt__(self, other):
rval = theano.tensor.basic.gt(self, other)
rval = self._wrap(theano.tensor.basic.gt(self, other))
rval._is_nonzero = False
return rval
def __ge__(self, other):
rval = theano.tensor.basic.ge(self, other)
rval = self._wrap(theano.tensor.basic.ge(self, other))
rval._is_nonzero = False
return rval
......@@ -81,39 +84,39 @@ class _tensor_py_operators:
# BITWISE
def __invert__(self):
return theano.tensor.basic.invert(self)
return self._wrap(theano.tensor.basic.invert(self))
def __and__(self, other):
return theano.tensor.basic.and_(self, other)
return self._wrap(theano.tensor.basic.and_(self, other))
def __or__(self, other):
return theano.tensor.basic.or_(self, other)
return self._wrap(theano.tensor.basic.or_(self, other))
def __xor__(self, other):
return theano.tensor.basic.xor(self, other)
return self._wrap(theano.tensor.basic.xor(self, other))
def __rand__(self, other):
return theano.tensor.basic.and_(other, self)
return self._wrap(theano.tensor.basic.and_(other, self))
def __ror__(self, other):
return theano.tensor.basic.or_(other, self)
return self._wrap(theano.tensor.basic.or_(other, self))
def __rxor__(self, other):
return theano.tensor.basic.xor(other, self)
return self._wrap(theano.tensor.basic.xor(other, self))
# def __iand__(self, other):
# return _and_inplace(self, other)
# return self._wrap(_and_inplace(self, other))
#
# def __ior__(self, other):
# return _or_inplace(self, other)
# return self._wrap(_or_inplace(self, other))
#
#def __ixor__(self, other):
# return _xor_inplace(self, other)
# return self._wrap(_xor_inplace(self, other))
# ARITHMETIC - NORMAL
def __add__(self, other):
try:
return theano.tensor.basic.add(self, other)
return self._wrap(theano.tensor.basic.add(self, other))
# We should catch the minimum number of exception here.
# Otherwise this will convert error when Theano flags
# compute_test_value is used
......@@ -132,7 +135,7 @@ class _tensor_py_operators:
# See explanation in __add__ for the error catched
# and the return value in that case
try:
return theano.tensor.basic.sub(self, other)
return self._wrap(theano.tensor.basic.sub(self, other))
except (NotImplementedError, AsTensorError):
return NotImplemented
......@@ -140,7 +143,7 @@ class _tensor_py_operators:
# See explanation in __add__ for the error catched
# and the return value in that case
try:
return theano.tensor.mul(self, other)
return self._wrap(theano.tensor.mul(self, other))
except (NotImplementedError, AsTensorError):
return NotImplemented
......@@ -148,7 +151,7 @@ class _tensor_py_operators:
# See explanation in __add__ for the error catched
# and the return value in that case
try:
return theano.tensor.basic.div_proxy(self, other)
return self._wrap(theano.tensor.basic.div_proxy(self, other))
except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
......@@ -162,7 +165,7 @@ class _tensor_py_operators:
# See explanation in __add__ for the error catched
# adn the return value in that case
try:
return theano.tensor.basic.pow(self, other)
return self._wrap(theano.tensor.basic.pow(self, other))
except (NotImplementedError, AsTensorError):
return NotImplemented
......@@ -170,7 +173,7 @@ class _tensor_py_operators:
# See explanation in __add__ for the error catched
# adn the return value in that case
try:
return theano.tensor.basic.mod_check(self, other)
return self._wrap(theano.tensor.basic.mod_check(self, other))
except ComplexError:
# This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number.
......@@ -179,55 +182,55 @@ class _tensor_py_operators:
return NotImplemented
def __truediv__(self, other):
return theano.tensor.basic.true_div(self, other)
return self._wrap(theano.tensor.basic.true_div(self, other))
def __floordiv__(self, other):
return theano.tensor.basic.floor_div(self, other)
return self._wrap(theano.tensor.basic.floor_div(self, other))
def __rtruediv__(self, other):
return theano.tensor.basic.true_div(other, self)
return self._wrap(theano.tensor.basic.true_div(other, self))
def __rfloordiv__(self, other):
return theano.tensor.basic.floor_div(other, self)
return self._wrap(theano.tensor.basic.floor_div(other, self))
##### DO NOT USE THESE BECAUSE INPLACE OPS SHOULD BE INSERTED
##### BY OPTIMIZATIONS ONLY
## ARITHMETIC - INPLACE
#def __iadd__(self, other):
# return _add_inplace(self, other)
# return self._wrap(_add_inplace(self, other))
#def __isub__(self, other):
# return _sub_inplace(self, other)
# return self._wrap(_sub_inplace(self, other))
#
#def __imul__(self, other):
# return _mul_inplace(self, other)
# return self._wrap(_mul_inplace(self, other))
#
#def __idiv__(self, other):
# return _div_inplace(self, other)
# return self._wrap(_div_inplace(self, other))
#
#def __ipow__(self, other):
# return _pow_inplace(self, other)
# return self._wrap(_pow_inplace(self, other))
# ARITHMETIC - RIGHT-OPERAND
def __radd__(self, other):
return theano.tensor.basic.add(other, self)
return self._wrap(theano.tensor.basic.add(other, self))
def __rsub__(self, other):
return theano.tensor.basic.sub(other, self)
return self._wrap(theano.tensor.basic.sub(other, self))
def __rmul__(self, other):
return theano.tensor.basic.mul(other, self)
return self._wrap(theano.tensor.basic.mul(other, self))
def __rdiv__(self, other):
return theano.tensor.basic.div_proxy(other, self)
return self._wrap(theano.tensor.basic.div_proxy(other, self))
def __rmod__(self, other):
return theano.tensor.basic.mod(other, self)
return self._wrap(theano.tensor.basic.mod(other, self))
def __rpow__(self, other):
return theano.tensor.basic.pow(other, self)
return self._wrap(theano.tensor.basic.pow(other, self))
# TRANSPOSE
T = property(lambda self: theano.tensor.basic.transpose(self))
T = property(lambda self: self._wrap(theano.tensor.basic.transpose(self)))
def transpose(self, *axes):
"""
......@@ -240,16 +243,16 @@ class _tensor_py_operators:
"""
if len(axes) == 0:
return theano.tensor.basic.transpose(self)
return self._wrap(theano.tensor.basic.transpose(self))
try:
iter(axes[0])
iterable = True
except TypeError:
iterable = False
if len(axes) == 1 and iterable:
return theano.tensor.basic.transpose(self, axes[0])
return self._wrap(theano.tensor.basic.transpose(self, axes[0]))
else:
return theano.tensor.basic.transpose(self, axes)
return self._wrap(theano.tensor.basic.transpose(self, axes))
shape = property(lambda self: theano.tensor.basic.shape(self))
......@@ -257,10 +260,12 @@ class _tensor_py_operators:
# We can't implement __len__ to provide a better error message.
def any(self, axis=None, keepdims=False):
return theano.tensor.basic.any(self, axis=axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.any(self, axis=axis,
keepdims=keepdims))
def all(self, axis=None, keepdims=False):
return theano.tensor.basic.all(self, axis=axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.all(self, axis=axis,
keepdims=keepdims))
# Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls
# __len__ before calling __getitem__. It also does not catch the raised
......@@ -293,7 +298,7 @@ class _tensor_py_operators:
raise ValueError("Expected ndim to be an integer, is " +
str(type(ndim)))
return theano.tensor.basic.reshape(self, shape, ndim=ndim)
return self._wrap(theano.tensor.basic.reshape(self, shape, ndim=ndim))
def dimshuffle(self, *pattern):
"""
......@@ -320,20 +325,21 @@ class _tensor_py_operators:
pattern = pattern[0]
op = theano.tensor.basic.DimShuffle(list(self.type.broadcastable),
pattern)
return op(self)
return self._wrap(op(self))
def flatten(self, ndim=1):
return theano.tensor.basic.flatten(self, ndim)
return self._wrap(theano.tensor.basic.flatten(self, ndim))
def ravel(self):
return theano.tensor.basic.flatten(self)
return self._wrap(theano.tensor.basic.flatten(self))
def diagonal(self, offset=0, axis1=0, axis2=1):
return theano.tensor.basic.diagonal(self, offset, axis1, axis2)
return self._wrap(theano.tensor.basic.diagonal(self, offset,
axis1, axis2))
# CASTING
def astype(self, dtype):
return theano.tensor.cast(self, dtype)
return self._wrap(theano.tensor.cast(self, dtype))
# SLICING
# Do not define __getslice__ here:
......@@ -375,9 +381,9 @@ class _tensor_py_operators:
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
return self._wrap(self.take(arg, axis))
else:
return theano.tensor.subtensor.AdvancedSubtensor()(self, *args)
return self._wrap(theano.tensor.subtensor.AdvancedSubtensor()(self, *args))
else:
if numpy.newaxis in args:
# None (aka np.newaxis) in numpy indexing means to add a
......@@ -399,23 +405,24 @@ class _tensor_py_operators:
new_args.append(arg)
view = self.dimshuffle(pattern)
rval = view.__getitem__(tuple(new_args))
return rval
return self._wrap(rval)
else:
return theano.tensor.subtensor.Subtensor(args)(
self, *theano.tensor.subtensor.Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))
return self._wrap(theano.tensor.subtensor.Subtensor(args)(
self, *theano.tensor.subtensor.Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable))))
def take(self, indices, axis=None, mode='raise'):
return theano.tensor.subtensor.take(self, indices, axis, mode)
return self._wrap(theano.tensor.subtensor.take(self, indices,
axis, mode))
# COPYING
def copy(self):
return theano.tensor.basic.tensor_copy(self)
return self._wrap(theano.tensor.basic.tensor_copy(self))
def __iter__(self):
try:
for i in xrange(theano.tensor.basic.get_vector_length(self)):
yield self[i]
yield self._wrap(self[i])
except TypeError:
# This prevents accidental iteration via builtin.sum(self)
raise TypeError(('TensorType does not support iteration. '
......@@ -437,24 +444,26 @@ class _tensor_py_operators:
# extra pseudo-operator symbols
def __dot__(left, right):
return theano.tensor.basic.dot(left, right)
return left._wrap(theano.tensor.basic.dot(left, right))
def __rdot__(right, left):
return theano.tensor.basic.dot(left, right)
return right._wrap(theano.tensor.basic.dot(left, right))
dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `theano.tensor.sum`"""
return theano.tensor.basic.sum(self, axis=axis,
dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype)
return self._wrap(theano.tensor.basic.sum(self, axis=axis,
dtype=dtype,
keepdims=keepdims,
acc_dtype=acc_dtype))
def prod(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `theano.tensor.prod`"""
return theano.tensor.basic.prod(self, axis=axis,
dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype)
return self._wrap(theano.tensor.basic.prod(self, axis=axis,
dtype=dtype,
keepdims=keepdims,
acc_dtype=acc_dtype))
def norm(self, L, axis=None):
if L == 0:
......@@ -462,80 +471,87 @@ class _tensor_py_operators:
if numpy.isinf(L):
raise NotImplementedError()
# optimizations will/should catch cases like L=1, L=2
return theano.tensor.basic.pow(
return self._wrap(theano.tensor.basic.pow(
theano.tensor.basic.pow(
theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L)
theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L))
def mean(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `theano.tensor.mean`"""
return theano.tensor.basic.mean(self, axis=axis,
dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype)
return self._wrap(theano.tensor.basic.mean(self, axis=axis,
dtype=dtype,
keepdims=keepdims,
acc_dtype=acc_dtype))
def var(self, axis=None, keepdims=False):
"""See `theano.tensor.var`"""
return theano.tensor.basic.var(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.var(self, axis,
keepdims=keepdims))
def std(self, axis=None, keepdims=False):
"""See `theano.tensor.std`"""
return theano.tensor.basic.std(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.std(self, axis,
keepdims=keepdims))
def min(self, axis=None, keepdims=False):
"""See `theano.tensor.min`"""
return theano.tensor.basic.min(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.min(self, axis,
keepdims=keepdims))
def max(self, axis=None, keepdims=False):
"""See `theano.tensor.max`"""
return theano.tensor.basic.max(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.max(self, axis,
keepdims=keepdims))
def argmin(self, axis=None, keepdims=False):
"""See `theano.tensor.argmin`"""
return theano.tensor.basic.argmin(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.argmin(self, axis,
keepdims=keepdims))
def argmax(self, axis=None, keepdims=False):
"""See `theano.tensor.argmax`"""
return theano.tensor.basic.argmax(self, axis, keepdims=keepdims)
return self._wrap(theano.tensor.basic.argmax(self, axis,
keepdims=keepdims))
def nonzero(self, return_matrix=False):
"""See `theano.tensor.nonzero`"""
return theano.tensor.basic.nonzero(self, return_matrix=return_matrix)
return self._wrap(theano.tensor.basic.nonzero(self, return_matrix=return_matrix))
def nonzero_values(self):
"""See `theano.tensor.nonzero_values`"""
return theano.tensor.basic.nonzero_values(self)
return self._wrap(theano.tensor.basic.nonzero_values(self))
def sort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.sort`"""
from theano.tensor.sort import sort
return sort(self, axis, kind, order)
return self._wrap(sort(self, axis, kind, order))
def argsort(self, axis=-1, kind='quicksort', order=None):
"""See `theano.tensor.argsort`"""
from theano.tensor.sort import argsort
return argsort(self, axis, kind, order)
return self._wrap(argsort(self, axis, kind, order))
def clip(self, a_min, a_max):
"Clip (limit) the values in an array."
return theano.tensor.basic.clip(self, a_min, a_max)
return self._wrap(theano.tensor.basic.clip(self, a_min, a_max))
def conj(self):
"""See `theano.tensor.conj`"""
return theano.tensor.basic.conj(self)
return self._wrap(theano.tensor.basic.conj(self))
conjugate = conj
def repeat(self, repeats, axis=None):
"""See `theano.tensor.repeat`"""
from theano.tensor.extra_ops import repeat
return repeat(self, repeats, axis)
return self._wrap(repeat(self, repeats, axis))
def round(self, mode="half_away_from_zero"):
"""See `theano.tensor.round`"""
return theano.tensor.basic.round(self, mode)
return self._wrap(theano.tensor.basic.round(self, mode))
def trace(self):
from theano.sandbox.linalg import trace
return trace(self)
return self._wrap(trace(self))
# TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
......@@ -543,9 +559,8 @@ class _tensor_py_operators:
def get_scalar_constant_value(self):
return theano.tensor.basic.get_scalar_constant_value(self)
def zeros_like(self, dtype=None):
return theano.tensor.basic.zeros_like(self, dtype=dtype)
def zeros_like(model, dtype=None):
return self._wrap(theano.tensor.basic.zeros_like(model, dtype=dtype))
class TensorVariable(_tensor_py_operators, Variable):
"""Subclass to add the tensor operators to the basic `Variable` class."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论