Commit a6fe5f61 (author: Brandon T. Willard; committer: Thomas Wiecki)

Refactor import relationship between theano.tensor.basic and theano.tensor.elemwise

This removes the kludgy object re-definitions that were used to avoid circular import errors.
Parent commit: da798b78
...@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args( ...@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
Tensor = TensorType Tensor = TensorType
# This bizarre push-import avoids a circular dependency.
elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
...@@ -1114,50 +1108,46 @@ def _conversion(real_value, name): ...@@ -1114,50 +1108,46 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the # what types you are casting to what. That logic is implemented by the
# `cast()` function below. # `cast()` function below.
_convert_to_bool = _conversion(elemwise.Elemwise(scal.convert_to_bool), "bool") _convert_to_bool = _conversion(Elemwise(scal.convert_to_bool), "bool")
"""Cast to boolean""" """Cast to boolean"""
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), "int8") _convert_to_int8 = _conversion(Elemwise(scal.convert_to_int8), "int8")
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), "int16") _convert_to_int16 = _conversion(Elemwise(scal.convert_to_int16), "int16")
"""Cast to 16-bit integer""" """Cast to 16-bit integer"""
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), "int32") _convert_to_int32 = _conversion(Elemwise(scal.convert_to_int32), "int32")
"""Cast to 32-bit integer""" """Cast to 32-bit integer"""
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), "int64") _convert_to_int64 = _conversion(Elemwise(scal.convert_to_int64), "int64")
"""Cast to 64-bit integer""" """Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(elemwise.Elemwise(scal.convert_to_uint8), "uint8") _convert_to_uint8 = _conversion(Elemwise(scal.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer""" """Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(elemwise.Elemwise(scal.convert_to_uint16), "uint16") _convert_to_uint16 = _conversion(Elemwise(scal.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer""" """Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(elemwise.Elemwise(scal.convert_to_uint32), "uint32") _convert_to_uint32 = _conversion(Elemwise(scal.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer""" """Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(elemwise.Elemwise(scal.convert_to_uint64), "uint64") _convert_to_uint64 = _conversion(Elemwise(scal.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer""" """Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(elemwise.Elemwise(scal.convert_to_float16), "float16") _convert_to_float16 = _conversion(Elemwise(scal.convert_to_float16), "float16")
"""Cast to half-precision floating point""" """Cast to half-precision floating point"""
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), "float32") _convert_to_float32 = _conversion(Elemwise(scal.convert_to_float32), "float32")
"""Cast to single-precision floating point""" """Cast to single-precision floating point"""
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), "float64") _convert_to_float64 = _conversion(Elemwise(scal.convert_to_float64), "float64")
"""Cast to double-precision floating point""" """Cast to double-precision floating point"""
_convert_to_complex64 = _conversion( _convert_to_complex64 = _conversion(Elemwise(scal.convert_to_complex64), "complex64")
elemwise.Elemwise(scal.convert_to_complex64), "complex64"
)
"""Cast to single-precision complex""" """Cast to single-precision complex"""
_convert_to_complex128 = _conversion( _convert_to_complex128 = _conversion(Elemwise(scal.convert_to_complex128), "complex128")
elemwise.Elemwise(scal.convert_to_complex128), "complex128"
)
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
...@@ -3194,7 +3184,7 @@ def register_transfer(fn): ...@@ -3194,7 +3184,7 @@ def register_transfer(fn):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = elemwise.Elemwise(scal.identity) tensor_copy = Elemwise(scal.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter()) pprint.assign(tensor_copy, printing.IgnorePrinter())
...@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
When axis is None (the default value), the sum is performed When axis is None (the default value), the sum is performed
over the flattened tensor. over the flattened tensor.
For full documentation see ``tensor.elemwise.Sum``. For full documentation see `Sum`.
In particular please pay attention to the important warning when using In particular please pay attention to the important warning when using
a custom acc_dtype. a custom acc_dtype.
...@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
""" """
out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input) out = Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims: if keepdims:
out = makeKeepDims(input, out, axis) out = makeKeepDims(input, out, axis)
...@@ -3264,7 +3254,7 @@ def prod( ...@@ -3264,7 +3254,7 @@ def prod(
return out return out
class Mean(elemwise.CAReduce): class Mean(CAReduce):
def __init__(self, axis=None): def __init__(self, axis=None):
super().__init__(scal.add, axis) super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1 assert self.axis is None or len(self.axis) == 1
...@@ -4839,7 +4829,7 @@ def get_vector_length(v): ...@@ -4839,7 +4829,7 @@ def get_vector_length(v):
# `Op`s # `Op`s
if ( if (
v.owner v.owner
and isinstance(v.owner.op, theano.tensor.elemwise.Elemwise) and isinstance(v.owner.op, Elemwise)
and len(v.owner.inputs) == 1 and len(v.owner.inputs) == 1
and len(v.owner.outputs) == 1 and len(v.owner.outputs) == 1
): ):
......
...@@ -2,55 +2,52 @@ from copy import copy ...@@ -2,55 +2,52 @@ from copy import copy
import numpy as np import numpy as np
import theano import theano.tensor.basic
from theano import scalar
from theano.configdefaults import config from theano.configdefaults import config
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.graph.basic import Apply from theano.graph.basic import Apply
from theano.graph.null_type import NullType from theano.graph.null_type import NullType
from theano.graph.op import COp, ExternalCOp, OpenMPOp from theano.graph.op import COp, ExternalCOp, OpenMPOp
from theano.graph.params_type import ParamsType from theano.graph.params_type import ParamsType
from theano.graph.utils import MethodNotDefined
from theano.link.c.basic import failure_code
from theano.misc.frozendict import frozendict from theano.misc.frozendict import frozendict
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import FunctionPrinter, pprint from theano.printing import FunctionPrinter, pprint
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
from theano.scalar.basic import (
AND,
OR,
XOR,
Add,
BinaryScalarOp,
Mul,
Scalar,
ScalarMaximum,
ScalarMinimum,
)
from theano.scalar.basic import add as scalar_add
from theano.scalar.basic import and_
from theano.scalar.basic import bool as scalar_bool
from theano.scalar.basic import identity as scalar_identity
from theano.scalar.basic import mul as scalar_mul
from theano.scalar.basic import (
or_,
scalar_maximum,
scalar_minimum,
second,
transfer_type,
upcast,
upcast_out,
)
from theano.tensor import elemwise_cgen as cgen from theano.tensor import elemwise_cgen as cgen
from theano.tensor.type import TensorType
from theano.utils import uniq from theano.utils import uniq
_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] _numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
# tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make TensorType instances, so we have these as
# placeholders and the tensor module fills them
def as_tensor_variable(data):
    """Placeholder that unconditionally raises.

    The comment preceding these placeholders states that the tensor module
    is expected to overwrite this name with the real implementation; until
    that happens, any call fails with an explanatory error.

    Parameters
    ----------
    data
        Ignored; accepted only so the placeholder has a call signature.

    Raises
    ------
    Exception
        Always.
    """
    # The first literal needs a trailing space: adjacent string literals are
    # concatenated verbatim, which previously produced "using thishere".
    raise Exception(
        "Circular dependencies prevent using this "
        "here. import tensor before elemwise"
    )
def TensorType(*inputs, **kwargs):
    """Placeholder that unconditionally raises.

    The comment preceding these placeholders states that the tensor module
    is expected to overwrite this name; until then every call fails.

    Parameters
    ----------
    *inputs, **kwargs
        Ignored; accepted so that any call signature reaches the error.

    Raises
    ------
    Exception
        Always.
    """
    raise Exception(
        "Circular dependencies prevent "
        "using this here. import tensor before elemwise"
    )
def TensorVariable(*inputs, **kwargs):
    """Placeholder that unconditionally raises.

    The comment preceding these placeholders states that the tensor module
    is expected to overwrite this name; until then every call fails.

    Parameters
    ----------
    *inputs, **kwargs
        Ignored; accepted so that any call signature reaches the error.

    Raises
    ------
    Exception
        Always.
    """
    raise Exception(
        "Circular dependencies "
        "prevent using this here. import tensor before elemwise"
    )
def TensorConstant(*inputs, **kwargs):
    """Placeholder that unconditionally raises.

    The comment preceding these placeholders states that the tensor module
    is expected to overwrite this name; until then every call fails.

    Parameters
    ----------
    *inputs, **kwargs
        Ignored; accepted so that any call signature reaches the error.

    Raises
    ------
    Exception
        Always.
    """
    raise Exception(
        "Circular dependencies "
        "prevent using this here. import tensor before elemwise"
    )
class DimShuffle(ExternalCOp): class DimShuffle(ExternalCOp):
""" """
Allows to reorder the dimensions of a tensor or insert or remove Allows to reorder the dimensions of a tensor or insert or remove
...@@ -138,9 +135,9 @@ class DimShuffle(ExternalCOp): ...@@ -138,9 +135,9 @@ class DimShuffle(ExternalCOp):
# because of importation issues related to TensorType. # because of importation issues related to TensorType.
return ParamsType( return ParamsType(
input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)), input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)),
_new_order=theano.tensor.lvector, _new_order=theano.tensor.basic.lvector,
transposition=TensorType(dtype="uint32", broadcastable=(False,)), transposition=TensorType(dtype="uint32", broadcastable=(False,)),
inplace=theano.scalar.bool, inplace=scalar_bool,
) )
@property @property
...@@ -221,7 +218,7 @@ class DimShuffle(ExternalCOp): ...@@ -221,7 +218,7 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name) super().__init__([self.c_func_file], self.c_func_name)
def make_node(self, _input): def make_node(self, _input):
input = as_tensor_variable(_input) input = theano.tensor.basic.as_tensor_variable(_input)
ib = tuple(input.type.broadcastable) ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable: if not ib == self.input_broadcastable:
if len(ib) != len(self.input_broadcastable): if len(ib) != len(self.input_broadcastable):
...@@ -294,6 +291,8 @@ class DimShuffle(ExternalCOp): ...@@ -294,6 +291,8 @@ class DimShuffle(ExternalCOp):
return self(*eval_points, **dict(return_list=True)) return self(*eval_points, **dict(return_list=True))
def grad(self, inp, grads): def grad(self, inp, grads):
from theano.tensor.basic import as_tensor_variable, discrete_dtypes
(x,) = inp (x,) = inp
(gz,) = grads (gz,) = grads
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
...@@ -304,12 +303,12 @@ class DimShuffle(ExternalCOp): ...@@ -304,12 +303,12 @@ class DimShuffle(ExternalCOp):
# Do not make the DimShuffle inplace as an optimization at the # Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace. # canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph. # The inplace will be reintroduced automatically later in the graph.
if inp[0].dtype in theano.tensor.discrete_dtypes: if inp[0].dtype in discrete_dtypes:
return [inp[0].zeros_like(dtype=config.floatX)] return [inp[0].zeros_like(dtype=config.floatX)]
else: else:
return [ return [
DimShuffle(gz.type.broadcastable, grad_order)( DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz) Elemwise(scalar_identity)(gz)
) )
] ]
...@@ -356,12 +355,12 @@ class Elemwise(OpenMPOp): ...@@ -356,12 +355,12 @@ class Elemwise(OpenMPOp):
being generalized to tensors. In particular, if the calculations being generalized to tensors. In particular, if the calculations
for an output are done inplace on an input, the output type must for an output are done inplace on an input, the output type must
be the same as the corresponding input type (see the doc of be the same as the corresponding input type (see the doc of
scalar.ScalarOp to get help about controlling the output type) `ScalarOp` to get help about controlling the output type)
Parameters Parameters
---------- ----------
scalar_op scalar_op
An instance of a subclass of scalar.ScalarOp which works uniquely An instance of a subclass of `ScalarOp` which works uniquely
on scalars. on scalars.
inplace_pattern inplace_pattern
A dictionary that maps the index of an output to the A dictionary that maps the index of an output to the
...@@ -496,7 +495,7 @@ second dimension ...@@ -496,7 +495,7 @@ second dimension
is left-completed to the greatest number of dimensions with 1s is left-completed to the greatest number of dimensions with 1s
using DimShuffle. using DimShuffle.
""" """
inputs = list(map(as_tensor_variable, inputs)) inputs = list(map(theano.tensor.basic.as_tensor_variable, inputs))
out_dtypes, out_broadcastables, inputs = self.get_output_info( out_dtypes, out_broadcastables, inputs = self.get_output_info(
DimShuffle, *inputs DimShuffle, *inputs
) )
...@@ -525,7 +524,7 @@ second dimension ...@@ -525,7 +524,7 @@ second dimension
# make such that _bgrads computes only the gradients of the # make such that _bgrads computes only the gradients of the
# current output on the inputs ( and not all outputs) # current output on the inputs ( and not all outputs)
ograds = [x.zeros_like() for x in outs] ograds = [x.zeros_like() for x in outs]
ograds[idx] = theano.tensor.ones_like(out) ograds[idx] = theano.tensor.basic.ones_like(out)
bgrads = self._bgrad(inputs, outs, ograds) bgrads = self._bgrad(inputs, outs, ograds)
rop_out = None rop_out = None
...@@ -559,32 +558,30 @@ second dimension ...@@ -559,32 +558,30 @@ second dimension
return [[True for output in node.outputs] for ipt in node.inputs] return [[True for output in node.outputs] for ipt in node.inputs]
def L_op(self, inputs, outs, ograds): def L_op(self, inputs, outs, ograds):
from theano.tensor.basic import continuous_dtypes, discrete_dtypes
from theano.tensor.basic import sum as tt_sum
# compute grad with respect to broadcasted input # Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds) rval = self._bgrad(inputs, outs, ograds)
# TODO: make sure that zeros are clearly identifiable # TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have # to the gradient.grad method when the outputs have
# some integer and some floating point outputs # some integer and some floating point outputs
if any(out.type.dtype not in theano.tensor.continuous_dtypes for out in outs): if any(out.type.dtype not in continuous_dtypes for out in outs):
# For integer output, return value may # For integer output, return value may only be zero or undefined
# only be zero or undefined # We don't bother with trying to check that the scalar ops
# We don't bother with trying to check # correctly returned something that evaluates to 0, we just make
# that the scalar ops correctly # the return value obviously zero so that gradient.grad can tell
# returned something that evaluates to 0, # this op did the right thing.
# we just make the return
# value obviously zero so that gradient.grad
# can tell this op did
# the right thing.
new_rval = [] new_rval = []
for elem, ipt in zip(rval, inputs): for elem, ipt in zip(rval, inputs):
if isinstance(elem.type, (NullType, DisconnectedType)): if isinstance(elem.type, (NullType, DisconnectedType)):
new_rval.append(elem) new_rval.append(elem)
else: else:
elem = ipt.zeros_like() elem = ipt.zeros_like()
if str(elem.type.dtype) not in theano.tensor.continuous_dtypes: if str(elem.type.dtype) not in continuous_dtypes:
elem = elem.astype(config.floatX) elem = elem.astype(config.floatX)
assert str(elem.type.dtype) not in theano.tensor.discrete_dtypes assert str(elem.type.dtype) not in discrete_dtypes
new_rval.append(elem) new_rval.append(elem)
return new_rval return new_rval
...@@ -593,9 +590,9 @@ second dimension ...@@ -593,9 +590,9 @@ second dimension
if isinstance(rval[i].type, (NullType, DisconnectedType)): if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue continue
# list of all the dimensions that are broadcastable for input[i] so # List of all the dimensions that are broadcastable for input[i] so
# we can sum over them # we can sum over them
# todo: only count dimensions that were effectively broadcasted # TODO: only count dimensions that were effectively broadcasted
to_sum = [ to_sum = [
j j
for j, bcast in enumerate(ipt.type.broadcastable) for j, bcast in enumerate(ipt.type.broadcastable)
...@@ -603,10 +600,8 @@ second dimension ...@@ -603,10 +600,8 @@ second dimension
] ]
if to_sum: if to_sum:
sr = theano.tensor.basic.sum(rval[i], axis=to_sum, keepdims=True) sr = tt_sum(rval[i], axis=to_sum, keepdims=True)
rval[i] = sr rval[i] = sr
# close if
# close for
return rval return rval
...@@ -653,7 +648,9 @@ second dimension ...@@ -653,7 +648,9 @@ second dimension
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent TensorType of size 1 and proper number of # an equivalent TensorType of size 1 and proper number of
# dimensions # dimensions
res = theano.tensor.constant(np.asarray(r.data), dtype=r.type.dtype) res = theano.tensor.basic.constant(
np.asarray(r.data), dtype=r.type.dtype
)
return DimShuffle((), ["x"] * nd)(res) return DimShuffle((), ["x"] * nd)(res)
new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs]) new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
...@@ -721,9 +718,9 @@ second dimension ...@@ -721,9 +718,9 @@ second dimension
# when the input is complex. So add it only when inputs is int. # when the input is complex. So add it only when inputs is int.
out_dtype = node.outputs[0].dtype out_dtype = node.outputs[0].dtype
if ( if (
out_dtype in theano.tensor.float_dtypes out_dtype in theano.tensor.basic.float_dtypes
and isinstance(self.nfunc, np.ufunc) and isinstance(self.nfunc, np.ufunc)
and node.inputs[0].dtype in theano.tensor.discrete_dtypes and node.inputs[0].dtype in theano.tensor.basic.discrete_dtypes
): ):
char = np.sctype2char(out_dtype) char = np.sctype2char(out_dtype)
sig = char * node.nin + "->" + char * node.nout sig = char * node.nin + "->" + char * node.nout
...@@ -871,7 +868,7 @@ second dimension ...@@ -871,7 +868,7 @@ second dimension
# there must be some input that is not broadcastable in # there must be some input that is not broadcastable in
# dimension 'dim' # dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs): for ishp, i in zip(i_shapes, node.inputs):
if isinstance(i.type, theano.scalar.Scalar): if isinstance(i.type, Scalar):
continue # we skip scalar continue # we skip scalar
if not i.type.broadcastable[dim]: if not i.type.broadcastable[dim]:
# input i is not broadcastable in position dim # input i is not broadcastable in position dim
...@@ -1032,7 +1029,7 @@ second dimension ...@@ -1032,7 +1029,7 @@ second dimension
if self.openmp: if self.openmp:
# If we are using openmp, we need to get rid of the "goto" # If we are using openmp, we need to get rid of the "goto"
# statement in sub['fail']. For now we recreate it here. # statement in sub['fail']. For now we recreate it here.
fail = theano.link.c.basic.failure_code(sub, use_goto=False) fail = failure_code(sub, use_goto=False)
else: else:
fail = sub["fail"] fail = sub["fail"]
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
...@@ -1139,7 +1136,7 @@ second dimension ...@@ -1139,7 +1136,7 @@ second dimension
contig = self.scalar_op.c_code_contiguous( contig = self.scalar_op.c_code_contiguous(
node, nodename + "_scalar_contig_", _inames, onames, sub node, nodename + "_scalar_contig_", _inames, onames, sub
) )
except theano.graph.utils.MethodNotDefined: except MethodNotDefined:
# Try to make one generic version, this will help the # Try to make one generic version, this will help the
# compiler to vectorize the code as their won't be as # compiler to vectorize the code as their won't be as
# many ptr and the stride will be hard coded. # many ptr and the stride will be hard coded.
...@@ -1343,24 +1340,25 @@ class CAReduce(COp): ...@@ -1343,24 +1340,25 @@ class CAReduce(COp):
self.set_ufunc(scalar_op) self.set_ufunc(scalar_op)
def set_ufunc(self, scalar_op): def set_ufunc(self, scalar_op):
# This is probably a speed up of the implementation # TODO FIXME: Why would we ever do this, instead of allowing the `Op`
if isinstance(scalar_op, theano.scalar.basic.Add): # itself to tell us which `ufunc` it should use?
if isinstance(scalar_op, Add):
self.ufunc = np.add self.ufunc = np.add
elif isinstance(scalar_op, theano.scalar.basic.Mul): elif isinstance(scalar_op, Mul):
self.ufunc = np.multiply self.ufunc = np.multiply
elif isinstance(scalar_op, theano.scalar.basic.ScalarMaximum): elif isinstance(scalar_op, ScalarMaximum):
self.ufunc = np.maximum self.ufunc = np.maximum
elif isinstance(scalar_op, theano.scalar.basic.ScalarMinimum): elif isinstance(scalar_op, ScalarMinimum):
self.ufunc = np.minimum self.ufunc = np.minimum
elif isinstance(scalar_op, theano.scalar.basic.AND) and _numpy_ver >= [1, 12]: elif isinstance(scalar_op, AND) and _numpy_ver >= [1, 12]:
# numpy.bitwise_and.identity was incorrect for versions before # numpy.bitwise_and.identity was incorrect for versions before
# 1.12 (it was 1 instead of -1), so we skip it in that case. # 1.12 (it was 1 instead of -1), so we skip it in that case.
# We will fall back to the "else:" case, which defines a # We will fall back to the "else:" case, which defines a
# ufunc without identity. # ufunc without identity.
self.ufunc = np.bitwise_and self.ufunc = np.bitwise_and
elif isinstance(scalar_op, theano.scalar.basic.OR): elif isinstance(scalar_op, OR):
self.ufunc = np.bitwise_or self.ufunc = np.bitwise_or
elif isinstance(scalar_op, theano.scalar.basic.XOR): elif isinstance(scalar_op, XOR):
self.ufunc = np.bitwise_xor self.ufunc = np.bitwise_xor
else: else:
self.ufunc = np.frompyfunc(scalar_op.impl, 2, 1) self.ufunc = np.frompyfunc(scalar_op.impl, 2, 1)
...@@ -1369,6 +1367,8 @@ class CAReduce(COp): ...@@ -1369,6 +1367,8 @@ class CAReduce(COp):
return input_dtype return input_dtype
def make_node(self, input): def make_node(self, input):
from theano.tensor.basic import as_tensor_variable
input = as_tensor_variable(input) input = as_tensor_variable(input)
if self.axis is not None: if self.axis is not None:
...@@ -1492,11 +1492,13 @@ class CAReduce(COp): ...@@ -1492,11 +1492,13 @@ class CAReduce(COp):
idtype = input.type.dtype_specs()[1] idtype = input.type.dtype_specs()[1]
odtype = output.type.dtype_specs()[1] odtype = output.type.dtype_specs()[1]
if hasattr(self, "acc_dtype") and self.acc_dtype is not None: acc_dtype = getattr(self, "acc_dtype", None)
if self.acc_dtype == "float16":
raise theano.graph.utils.MethodNotDefined("no c_code for " "float16") if acc_dtype is not None:
if acc_dtype == "float16":
raise MethodNotDefined("no c_code for float16")
acc_type = TensorType( acc_type = TensorType(
broadcastable=node.outputs[0].broadcastable, dtype=self.acc_dtype broadcastable=node.outputs[0].broadcastable, dtype=acc_dtype
) )
adtype = acc_type.dtype_specs()[1] adtype = acc_type.dtype_specs()[1]
else: else:
...@@ -1509,9 +1511,9 @@ class CAReduce(COp): ...@@ -1509,9 +1511,9 @@ class CAReduce(COp):
if len(axis) == 0: if len(axis) == 0:
# The acc_dtype is never a downcast compared to the input dtype # The acc_dtype is never a downcast compared to the input dtype
# So we just need a cast to the output dtype. # So we just need a cast to the output dtype.
var = theano.tensor.cast(input, node.outputs[0].dtype) var = theano.tensor.basic.cast(input, node.outputs[0].dtype)
if var is input: if var is input:
var = Elemwise(scalar.identity)(input) var = Elemwise(scalar_identity)(input)
assert var.dtype == node.outputs[0].dtype assert var.dtype == node.outputs[0].dtype
return var.owner.op._c_all(var.owner, name, inames, onames, sub) return var.owner.op._c_all(var.owner, name, inames, onames, sub)
...@@ -1570,8 +1572,8 @@ class CAReduce(COp): ...@@ -1570,8 +1572,8 @@ class CAReduce(COp):
if hasattr(self.scalar_op, "identity"): if hasattr(self.scalar_op, "identity"):
identity = self.scalar_op.identity identity = self.scalar_op.identity
elif self.scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]: elif self.scalar_op in [scalar_maximum, scalar_minimum]:
if self.scalar_op == scalar.scalar_maximum: if self.scalar_op == scalar_maximum:
scal_name = "maximum" scal_name = "maximum"
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()" identity = "-__builtin_inf()"
...@@ -1580,7 +1582,7 @@ class CAReduce(COp): ...@@ -1580,7 +1582,7 @@ class CAReduce(COp):
identity = "0" identity = "0"
else: else:
identity = "NPY_MIN_" + str(input.type.dtype).upper() identity = "NPY_MIN_" + str(input.type.dtype).upper()
if self.scalar_op == scalar.scalar_minimum: if self.scalar_op == scalar_minimum:
scal_name = "minimum" scal_name = "minimum"
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()" identity = "__builtin_inf()"
...@@ -1723,7 +1725,7 @@ class All(CAReduce): ...@@ -1723,7 +1725,7 @@ class All(CAReduce):
nfunc_spec = ("all", 1, 1) nfunc_spec = ("all", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis) super().__init__(and_, axis)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
...@@ -1735,9 +1737,11 @@ class All(CAReduce): ...@@ -1735,9 +1737,11 @@ class All(CAReduce):
return "All{%s}" % ", ".join(map(str, self.axis)) return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
from theano.tensor.basic import as_tensor_variable, neq
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
input = theano.tensor.neq(input, 0) input = neq(input, 0)
ret = super().make_node(input) ret = super().make_node(input)
return ret return ret
...@@ -1756,7 +1760,7 @@ class Any(CAReduce): ...@@ -1756,7 +1760,7 @@ class Any(CAReduce):
nfunc_spec = ("any", 1, 1) nfunc_spec = ("any", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis) super().__init__(or_, axis)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
...@@ -1768,9 +1772,11 @@ class Any(CAReduce): ...@@ -1768,9 +1772,11 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis)) return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
from theano.tensor.basic import as_tensor_variable, neq
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
input = theano.tensor.neq(input, 0) input = neq(input, 0)
ret = super().make_node(input) ret = super().make_node(input)
return ret return ret
...@@ -1835,7 +1841,7 @@ class CAReduceDtype(CAReduce): ...@@ -1835,7 +1841,7 @@ class CAReduceDtype(CAReduce):
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype") __props__ = ("scalar_op", "axis", "dtype", "acc_dtype")
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None): def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
CAReduce.__init__(self, scalar_op, axis=axis) super().__init__(scalar_op, axis=axis)
self.dtype = dtype self.dtype = dtype
self.acc_dtype = acc_dtype self.acc_dtype = acc_dtype
...@@ -1895,14 +1901,14 @@ class CAReduceDtype(CAReduce): ...@@ -1895,14 +1901,14 @@ class CAReduceDtype(CAReduce):
complex64="complex128", complex64="complex128",
).get(idtype, idtype) ).get(idtype, idtype)
elif ( elif (
acc_dtype in theano.tensor.continuous_dtypes acc_dtype in theano.tensor.basic.continuous_dtypes
and idtype in theano.tensor.discrete_dtypes and idtype in theano.tensor.basic.discrete_dtypes
): ):
# Specifying a continuous accumulator for discrete input is OK # Specifying a continuous accumulator for discrete input is OK
return acc_dtype return acc_dtype
else: else:
# The conversion has to be considered an upcast. # The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, acc_dtype) upcasted_dtype = upcast(idtype, acc_dtype)
if acc_dtype != upcasted_dtype: if acc_dtype != upcasted_dtype:
raise TypeError( raise TypeError(
f"Cannot build {self} node with input dtype {idtype} " f"Cannot build {self} node with input dtype {idtype} "
...@@ -1922,11 +1928,13 @@ class CAReduceDtype(CAReduce): ...@@ -1922,11 +1928,13 @@ class CAReduceDtype(CAReduce):
# We need to redefine make_node so that, if self.dtype is None, # We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op # we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype. # of the appropriate dtype.
input = as_tensor_variable(input) input = theano.tensor.basic.as_tensor_variable(input)
dtype = self._output_dtype(input.dtype) dtype = self._output_dtype(input.dtype)
acc_dtype = self._acc_dtype(input.dtype) acc_dtype = self._acc_dtype(input.dtype)
assert dtype is not None assert dtype is not None
assert acc_dtype is not None assert acc_dtype is not None
if dtype == self.dtype and acc_dtype == self.acc_dtype: if dtype == self.dtype and acc_dtype == self.acc_dtype:
# Don't build another instance # Don't build another instance
op = self op = self
...@@ -1937,7 +1945,10 @@ class CAReduceDtype(CAReduce): ...@@ -1937,7 +1945,10 @@ class CAReduceDtype(CAReduce):
op.acc_dtype = acc_dtype op.acc_dtype = acc_dtype
assert op.acc_dtype is not None assert op.acc_dtype is not None
return CAReduce.make_node(op, input)
# TODO: Why doesn't `make_node` just take these
# automatically-determined values as arguments?
return super(CAReduceDtype, op).make_node(input)
def __str__(self): def __str__(self):
name = self.__class__.__name__ name = self.__class__.__name__
...@@ -1990,9 +2001,7 @@ class Sum(CAReduceDtype): ...@@ -1990,9 +2001,7 @@ class Sum(CAReduceDtype):
nfunc_spec = ("sum", 1, 1) nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__( super().__init__(scalar_add, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
self, scalar.add, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
def __str__(self): def __str__(self):
name = self.__class__.__name__ name = self.__class__.__name__
...@@ -2003,9 +2012,11 @@ class Sum(CAReduceDtype): ...@@ -2003,9 +2012,11 @@ class Sum(CAReduceDtype):
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}" return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
def L_op(self, inp, out, grads): def L_op(self, inp, out, grads):
from theano.tensor.basic import as_tensor_variable
(x,) = inp (x,) = inp
if out[0].dtype not in theano.tensor.continuous_dtypes: if out[0].dtype not in theano.tensor.basic.continuous_dtypes:
return [x.zeros_like(dtype=config.floatX)] return [x.zeros_like(dtype=config.floatX)]
(gz,) = grads (gz,) = grads
...@@ -2024,7 +2035,7 @@ class Sum(CAReduceDtype): ...@@ -2024,7 +2035,7 @@ class Sum(CAReduceDtype):
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims) ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(scalar.second)(x, ds_op(gz)) gx = Elemwise(second)(x, ds_op(gz))
return [gx] return [gx]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -2039,7 +2050,7 @@ class Prod(CAReduceDtype): ...@@ -2039,7 +2050,7 @@ class Prod(CAReduceDtype):
""" """
Multiplies all the values of a tensor along the specified axis(es). Multiplies all the values of a tensor along the specified axis(es).
Equivalent to `CAReduce(scalar.prod, axis = axis)`, with the Equivalent to `CAReduce(scalar.mul, axis = axis)`, with the
difference that this defines the gradient of prod wrt its tensor difference that this defines the gradient of prod wrt its tensor
input. input.
...@@ -2049,9 +2060,7 @@ class Prod(CAReduceDtype): ...@@ -2049,9 +2060,7 @@ class Prod(CAReduceDtype):
nfunc_spec = ("sum", 1, 1) nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__( super().__init__(scalar_mul, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
self, scalar.mul, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
self.no_zeros_in_input = no_zeros_in_input self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct): def __setstate__(self, dct):
...@@ -2106,13 +2115,14 @@ class Prod(CAReduceDtype): ...@@ -2106,13 +2115,14 @@ class Prod(CAReduceDtype):
based on the result of this count. based on the result of this count.
""" """
from theano.tensor.basic import as_tensor_variable, discrete_dtypes, eq, neq
from theano.tensor.basic import sum as tt_sum
from theano.tensor.basic import switch
(prod_in,) = inp (prod_in,) = inp
(gz,) = grads (gz,) = grads
if ( if out[0].dtype in discrete_dtypes or self.acc_dtype in discrete_dtypes:
out[0].dtype in theano.tensor.discrete_dtypes
or self.acc_dtype in theano.tensor.discrete_dtypes
):
# There is an int conversion in the way # There is an int conversion in the way
return [prod_in.zeros_like(dtype=config.floatX)] return [prod_in.zeros_like(dtype=config.floatX)]
...@@ -2147,17 +2157,16 @@ class Prod(CAReduceDtype): ...@@ -2147,17 +2157,16 @@ class Prod(CAReduceDtype):
# this handles inputs with zeros, but only certain input shapes # this handles inputs with zeros, but only certain input shapes
return [grad_case_without_zeros] return [grad_case_without_zeros]
else: else:
T = theano.tensor
where_zeros = T.eq(prod_in, 0.0) where_zeros = eq(prod_in, 0.0)
sum_where_zeros = T.sum(where_zeros, axis=self.axis) sum_where_zeros = tt_sum(where_zeros, axis=self.axis)
groups_with_single_zero = T.eq(sum_where_zeros, 1).dimshuffle(new_dims) groups_with_single_zero = eq(sum_where_zeros, 1).dimshuffle(new_dims)
# tensor with 0 everywhere except for those places where # tensor with 0 everywhere except for those places where
# a 0 part of a group with a single zero was to be found # a 0 part of a group with a single zero was to be found
where_single_zero = groups_with_single_zero * where_zeros where_single_zero = groups_with_single_zero * where_zeros
# further optimization to avoid computing ProdWithoutZeros # further optimization to avoid computing ProdWithoutZeros
# if the incoming gradient is 0 # if the incoming gradient is 0
where_gz_not_zero = T.neq(gz, 0.0) where_gz_not_zero = neq(gz, 0.0)
# only take ProdWithoutZeros for the groups with single zeros # only take ProdWithoutZeros for the groups with single zeros
# with non-null incoming gradient # with non-null incoming gradient
where_to_take_prod_without_zeros = ( where_to_take_prod_without_zeros = (
...@@ -2173,12 +2182,12 @@ class Prod(CAReduceDtype): ...@@ -2173,12 +2182,12 @@ class Prod(CAReduceDtype):
prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in) prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in)
prod_without_zeros = prod_without_zeros.dimshuffle(new_dims) prod_without_zeros = prod_without_zeros.dimshuffle(new_dims)
groups_without_zeros = T.eq(sum_where_zeros, 0).dimshuffle(new_dims) groups_without_zeros = eq(sum_where_zeros, 0).dimshuffle(new_dims)
final_grad = T.switch( final_grad = switch(
groups_without_zeros, groups_without_zeros,
grad_case_without_zeros, grad_case_without_zeros,
T.switch(where_single_zero, prod_without_zeros, 0.0) * gz, switch(where_single_zero, prod_without_zeros, 0.0) * gz,
) )
return [final_grad] return [final_grad]
...@@ -2187,7 +2196,7 @@ class Prod(CAReduceDtype): ...@@ -2187,7 +2196,7 @@ class Prod(CAReduceDtype):
return (1,) return (1,)
class MulWithoutZeros(scalar.BinaryScalarOp): class MulWithoutZeros(BinaryScalarOp):
# "identity" here is zero, as in Reduce we don't want to start # "identity" here is zero, as in Reduce we don't want to start
# with reducing (1, something_else): this leads to the erroneous # with reducing (1, something_else): this leads to the erroneous
# case where a vector of zeros is reduced by binary reductions # case where a vector of zeros is reduced by binary reductions
...@@ -2217,7 +2226,7 @@ class MulWithoutZeros(scalar.BinaryScalarOp): ...@@ -2217,7 +2226,7 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
return (1,) return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name="mul_without_zeros") mul_without_zeros = MulWithoutZeros(upcast_out, name="mul_without_zeros")
class ProdWithoutZeros(CAReduceDtype): class ProdWithoutZeros(CAReduceDtype):
...@@ -2225,13 +2234,13 @@ class ProdWithoutZeros(CAReduceDtype): ...@@ -2225,13 +2234,13 @@ class ProdWithoutZeros(CAReduceDtype):
__props__ = ("axis", "dtype", "acc_dtype") __props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__( super().__init__(mul_without_zeros, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
self, mul_without_zeros, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
def grad(self, inp, grads): def grad(self, inp, grads):
from theano.gradient import grad_not_implemented
(a,) = inp (a,) = inp
a_grad = theano.gradient.grad_not_implemented( a_grad = grad_not_implemented(
self, self,
0, 0,
a, a,
...@@ -2253,6 +2262,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -2253,6 +2262,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
not take a NumPy array argument to put its result in. not take a NumPy array argument to put its result in.
""" """
import theano.scalar as scalar
def construct(symbol): def construct(symbol):
nonlocal symbolname nonlocal symbolname
...@@ -2262,7 +2272,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -2262,7 +2272,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
if symbolname.endswith("_inplace"): if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}" elemwise_name = f"Elemwise{{{symbolname},inplace}}"
scalar_op = getattr(scalar, symbolname[: -len("_inplace")]) scalar_op = getattr(scalar, symbolname[: -len("_inplace")])
inplace_scalar_op = scalar_op.__class__(scalar.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise( rval = Elemwise(
inplace_scalar_op, inplace_scalar_op,
{0: 0}, {0: 0},
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论