提交 a6fe5f61 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Refactor import relationship between theano.tensor.basic and theano.tensor.elemwise

This removes the kludgy object re-definitions that were used to avoid circular import errors.
上级 da798b78
......@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
Tensor = TensorType
# This bizarre push-import avoids a circular dependency.
elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
#########################
# Casting Operations
#########################
......@@ -1114,50 +1108,46 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
_convert_to_bool = _conversion(elemwise.Elemwise(scal.convert_to_bool), "bool")
_convert_to_bool = _conversion(Elemwise(scal.convert_to_bool), "bool")
"""Cast to boolean"""
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), "int8")
_convert_to_int8 = _conversion(Elemwise(scal.convert_to_int8), "int8")
"""Cast to 8-bit integer"""
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), "int16")
_convert_to_int16 = _conversion(Elemwise(scal.convert_to_int16), "int16")
"""Cast to 16-bit integer"""
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), "int32")
_convert_to_int32 = _conversion(Elemwise(scal.convert_to_int32), "int32")
"""Cast to 32-bit integer"""
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), "int64")
_convert_to_int64 = _conversion(Elemwise(scal.convert_to_int64), "int64")
"""Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(elemwise.Elemwise(scal.convert_to_uint8), "uint8")
_convert_to_uint8 = _conversion(Elemwise(scal.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(elemwise.Elemwise(scal.convert_to_uint16), "uint16")
_convert_to_uint16 = _conversion(Elemwise(scal.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(elemwise.Elemwise(scal.convert_to_uint32), "uint32")
_convert_to_uint32 = _conversion(Elemwise(scal.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(elemwise.Elemwise(scal.convert_to_uint64), "uint64")
_convert_to_uint64 = _conversion(Elemwise(scal.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(elemwise.Elemwise(scal.convert_to_float16), "float16")
_convert_to_float16 = _conversion(Elemwise(scal.convert_to_float16), "float16")
"""Cast to half-precision floating point"""
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), "float32")
_convert_to_float32 = _conversion(Elemwise(scal.convert_to_float32), "float32")
"""Cast to single-precision floating point"""
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), "float64")
_convert_to_float64 = _conversion(Elemwise(scal.convert_to_float64), "float64")
"""Cast to double-precision floating point"""
_convert_to_complex64 = _conversion(
elemwise.Elemwise(scal.convert_to_complex64), "complex64"
)
_convert_to_complex64 = _conversion(Elemwise(scal.convert_to_complex64), "complex64")
"""Cast to single-precision complex"""
_convert_to_complex128 = _conversion(
elemwise.Elemwise(scal.convert_to_complex128), "complex128"
)
_convert_to_complex128 = _conversion(Elemwise(scal.convert_to_complex128), "complex128")
"""Cast to double-precision complex"""
_cast_mapping = {
......@@ -3194,7 +3184,7 @@ def register_transfer(fn):
"""Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = elemwise.Elemwise(scal.identity)
tensor_copy = Elemwise(scal.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter())
......@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
When axis is None (the default value), the sum is performed
over the flattened tensor.
For full documentation see ``tensor.elemwise.Sum``.
For full documentation see `Sum`.
In particular please pay attention to the important warning when using
a custom acc_dtype.
......@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""
out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
out = Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
......@@ -3264,7 +3254,7 @@ def prod(
return out
class Mean(elemwise.CAReduce):
class Mean(CAReduce):
def __init__(self, axis=None):
super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1
......@@ -4839,7 +4829,7 @@ def get_vector_length(v):
# `Op`s
if (
v.owner
and isinstance(v.owner.op, theano.tensor.elemwise.Elemwise)
and isinstance(v.owner.op, Elemwise)
and len(v.owner.inputs) == 1
and len(v.owner.outputs) == 1
):
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论