提交 a6fe5f61 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Refactor import relationship between theano.tensor.basic and theano.tensor.elemwise

This removes the kludgy object re-definitions that were used to avoid circular import errors.
上级 da798b78
...@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args( ...@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
Tensor = TensorType Tensor = TensorType
# This bizarre push-import avoids a circular dependency.
elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
...@@ -1114,50 +1108,46 @@ def _conversion(real_value, name): ...@@ -1114,50 +1108,46 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the # what types you are casting to what. That logic is implemented by the
# `cast()` function below. # `cast()` function below.
_convert_to_bool = _conversion(elemwise.Elemwise(scal.convert_to_bool), "bool") _convert_to_bool = _conversion(Elemwise(scal.convert_to_bool), "bool")
"""Cast to boolean""" """Cast to boolean"""
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), "int8") _convert_to_int8 = _conversion(Elemwise(scal.convert_to_int8), "int8")
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), "int16") _convert_to_int16 = _conversion(Elemwise(scal.convert_to_int16), "int16")
"""Cast to 16-bit integer""" """Cast to 16-bit integer"""
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), "int32") _convert_to_int32 = _conversion(Elemwise(scal.convert_to_int32), "int32")
"""Cast to 32-bit integer""" """Cast to 32-bit integer"""
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), "int64") _convert_to_int64 = _conversion(Elemwise(scal.convert_to_int64), "int64")
"""Cast to 64-bit integer""" """Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(elemwise.Elemwise(scal.convert_to_uint8), "uint8") _convert_to_uint8 = _conversion(Elemwise(scal.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer""" """Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(elemwise.Elemwise(scal.convert_to_uint16), "uint16") _convert_to_uint16 = _conversion(Elemwise(scal.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer""" """Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(elemwise.Elemwise(scal.convert_to_uint32), "uint32") _convert_to_uint32 = _conversion(Elemwise(scal.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer""" """Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(elemwise.Elemwise(scal.convert_to_uint64), "uint64") _convert_to_uint64 = _conversion(Elemwise(scal.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer""" """Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(elemwise.Elemwise(scal.convert_to_float16), "float16") _convert_to_float16 = _conversion(Elemwise(scal.convert_to_float16), "float16")
"""Cast to half-precision floating point""" """Cast to half-precision floating point"""
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), "float32") _convert_to_float32 = _conversion(Elemwise(scal.convert_to_float32), "float32")
"""Cast to single-precision floating point""" """Cast to single-precision floating point"""
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), "float64") _convert_to_float64 = _conversion(Elemwise(scal.convert_to_float64), "float64")
"""Cast to double-precision floating point""" """Cast to double-precision floating point"""
_convert_to_complex64 = _conversion( _convert_to_complex64 = _conversion(Elemwise(scal.convert_to_complex64), "complex64")
elemwise.Elemwise(scal.convert_to_complex64), "complex64"
)
"""Cast to single-precision complex""" """Cast to single-precision complex"""
_convert_to_complex128 = _conversion( _convert_to_complex128 = _conversion(Elemwise(scal.convert_to_complex128), "complex128")
elemwise.Elemwise(scal.convert_to_complex128), "complex128"
)
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
...@@ -3194,7 +3184,7 @@ def register_transfer(fn): ...@@ -3194,7 +3184,7 @@ def register_transfer(fn):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = elemwise.Elemwise(scal.identity) tensor_copy = Elemwise(scal.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter()) pprint.assign(tensor_copy, printing.IgnorePrinter())
...@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
When axis is None (the default value), the sum is performed When axis is None (the default value), the sum is performed
over the flattened tensor. over the flattened tensor.
For full documentation see ``tensor.elemwise.Sum``. For full documentation see `Sum`.
In particular please pay attention to the important warning when using In particular please pay attention to the important warning when using
a custom acc_dtype. a custom acc_dtype.
...@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
""" """
out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input) out = Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims: if keepdims:
out = makeKeepDims(input, out, axis) out = makeKeepDims(input, out, axis)
...@@ -3264,7 +3254,7 @@ def prod( ...@@ -3264,7 +3254,7 @@ def prod(
return out return out
class Mean(elemwise.CAReduce): class Mean(CAReduce):
def __init__(self, axis=None): def __init__(self, axis=None):
super().__init__(scal.add, axis) super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1 assert self.axis is None or len(self.axis) == 1
...@@ -4839,7 +4829,7 @@ def get_vector_length(v): ...@@ -4839,7 +4829,7 @@ def get_vector_length(v):
# `Op`s # `Op`s
if ( if (
v.owner v.owner
and isinstance(v.owner.op, theano.tensor.elemwise.Elemwise) and isinstance(v.owner.op, Elemwise)
and len(v.owner.inputs) == 1 and len(v.owner.inputs) == 1
and len(v.owner.outputs) == 1 and len(v.owner.outputs) == 1
): ):
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论