Commit a6fe5f61 authored by Brandon T. Willard, committed by Thomas Wiecki

Refactor import relationship between theano.tensor.basic and theano.tensor.elemwise

This removes the kludgy object re-definitions that were used to avoid circular import errors.
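For reference, the removed pattern and its replacement look roughly like this (a minimal sketch with simplified bodies, not the exact module contents):

```python
# Old "push-import" kludge: elemwise.py defines placeholders that raise,
# and basic.py overwrites them once its real definitions exist.

# theano/tensor/elemwise.py
def as_tensor_variable(data):  # placeholder
    raise Exception(
        "Circular dependencies prevent using this here. "
        "import tensor before elemwise"
    )

# theano/tensor/basic.py
# elemwise.as_tensor_variable = as_tensor_variable  # the "push"

# New approach: defer the import to call time, inside the functions that
# need it, so nothing circular happens while the modules are loading.
def make_node(self, _input):
    from theano.tensor.basic import as_tensor_variable
    return as_tensor_variable(_input)
```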
Parent da798b78
@@ -1006,12 +1006,6 @@ tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args(
Tensor = TensorType
# This bizarre push-import avoids a circular dependency.
elemwise.as_tensor_variable = as_tensor_variable
elemwise.TensorType = TensorType
elemwise.TensorVariable = TensorVariable
elemwise.TensorConstant = TensorConstant
#########################
# Casting Operations
#########################
@@ -1114,50 +1108,46 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
_convert_to_bool = _conversion(elemwise.Elemwise(scal.convert_to_bool), "bool")
_convert_to_bool = _conversion(Elemwise(scal.convert_to_bool), "bool")
"""Cast to boolean"""
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), "int8")
_convert_to_int8 = _conversion(Elemwise(scal.convert_to_int8), "int8")
"""Cast to 8-bit integer"""
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), "int16")
_convert_to_int16 = _conversion(Elemwise(scal.convert_to_int16), "int16")
"""Cast to 16-bit integer"""
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), "int32")
_convert_to_int32 = _conversion(Elemwise(scal.convert_to_int32), "int32")
"""Cast to 32-bit integer"""
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), "int64")
_convert_to_int64 = _conversion(Elemwise(scal.convert_to_int64), "int64")
"""Cast to 64-bit integer"""
_convert_to_uint8 = _conversion(elemwise.Elemwise(scal.convert_to_uint8), "uint8")
_convert_to_uint8 = _conversion(Elemwise(scal.convert_to_uint8), "uint8")
"""Cast to unsigned 8-bit integer"""
_convert_to_uint16 = _conversion(elemwise.Elemwise(scal.convert_to_uint16), "uint16")
_convert_to_uint16 = _conversion(Elemwise(scal.convert_to_uint16), "uint16")
"""Cast to unsigned 16-bit integer"""
_convert_to_uint32 = _conversion(elemwise.Elemwise(scal.convert_to_uint32), "uint32")
_convert_to_uint32 = _conversion(Elemwise(scal.convert_to_uint32), "uint32")
"""Cast to unsigned 32-bit integer"""
_convert_to_uint64 = _conversion(elemwise.Elemwise(scal.convert_to_uint64), "uint64")
_convert_to_uint64 = _conversion(Elemwise(scal.convert_to_uint64), "uint64")
"""Cast to unsigned 64-bit integer"""
_convert_to_float16 = _conversion(elemwise.Elemwise(scal.convert_to_float16), "float16")
_convert_to_float16 = _conversion(Elemwise(scal.convert_to_float16), "float16")
"""Cast to half-precision floating point"""
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), "float32")
_convert_to_float32 = _conversion(Elemwise(scal.convert_to_float32), "float32")
"""Cast to single-precision floating point"""
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), "float64")
_convert_to_float64 = _conversion(Elemwise(scal.convert_to_float64), "float64")
"""Cast to double-precision floating point"""
_convert_to_complex64 = _conversion(
elemwise.Elemwise(scal.convert_to_complex64), "complex64"
)
_convert_to_complex64 = _conversion(Elemwise(scal.convert_to_complex64), "complex64")
"""Cast to single-precision complex"""
_convert_to_complex128 = _conversion(
elemwise.Elemwise(scal.convert_to_complex128), "complex128"
)
_convert_to_complex128 = _conversion(Elemwise(scal.convert_to_complex128), "complex128")
"""Cast to double-precision complex"""
_cast_mapping = {
@@ -3194,7 +3184,7 @@ def register_transfer(fn):
"""Create a duplicate of `a` (with duplicated storage)"""
tensor_copy = elemwise.Elemwise(scal.identity)
tensor_copy = Elemwise(scal.identity)
pprint.assign(tensor_copy, printing.IgnorePrinter())
@@ -3206,7 +3196,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
When axis is None (the default value), the sum is performed
over the flattened tensor.
For full documentation see ``tensor.elemwise.Sum``.
For full documentation see `Sum`.
In particular please pay attention to the important warning when using
a custom acc_dtype.
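For illustration, a small usage sketch of the default `axis=None` behavior (assuming the public `theano.tensor` API):

```python
import numpy as np
import theano
import theano.tensor as tt

x = tt.dmatrix("x")
f = theano.function([x], [tt.sum(x), tt.sum(x, axis=1)])
total, rows = f(np.arange(6.0).reshape(2, 3))
# total == 15.0 (sum over the flattened tensor),
# rows == [3., 12.] (sum along the second dimension)
```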
@@ -3219,7 +3209,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""
out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
out = Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
@@ -3264,7 +3254,7 @@ def prod(
return out
class Mean(elemwise.CAReduce):
class Mean(CAReduce):
def __init__(self, axis=None):
super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1
@@ -4839,7 +4829,7 @@ def get_vector_length(v):
# `Op`s
if (
v.owner
and isinstance(v.owner.op, theano.tensor.elemwise.Elemwise)
and isinstance(v.owner.op, Elemwise)
and len(v.owner.inputs) == 1
and len(v.owner.outputs) == 1
):
@@ -2,55 +2,52 @@ from copy import copy
import numpy as np
import theano
from theano import scalar
import theano.tensor.basic
from theano.configdefaults import config
from theano.gradient import DisconnectedType
from theano.graph.basic import Apply
from theano.graph.null_type import NullType
from theano.graph.op import COp, ExternalCOp, OpenMPOp
from theano.graph.params_type import ParamsType
from theano.graph.utils import MethodNotDefined
from theano.link.c.basic import failure_code
from theano.misc.frozendict import frozendict
from theano.misc.safe_asarray import _asarray
from theano.printing import FunctionPrinter, pprint
from theano.scalar import get_scalar_type
from theano.scalar.basic import (
AND,
OR,
XOR,
Add,
BinaryScalarOp,
Mul,
Scalar,
ScalarMaximum,
ScalarMinimum,
)
from theano.scalar.basic import add as scalar_add
from theano.scalar.basic import and_
from theano.scalar.basic import bool as scalar_bool
from theano.scalar.basic import identity as scalar_identity
from theano.scalar.basic import mul as scalar_mul
from theano.scalar.basic import (
or_,
scalar_maximum,
scalar_minimum,
second,
transfer_type,
upcast,
upcast_out,
)
from theano.tensor import elemwise_cgen as cgen
from theano.tensor.type import TensorType
from theano.utils import uniq
_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]]
# tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make TensorType instances, so we have these as
# placeholders and the tensor module fills them
def as_tensor_variable(data):
raise Exception(
"Circular dependencies prevent using this here. import tensor before elemwise"
)
def TensorType(*inputs, **kwargs):
raise Exception(
"Circular dependencies prevent "
"using this here. import tensor before elemwise"
)
def TensorVariable(*inputs, **kwargs):
raise Exception(
"Circular dependencies "
"prevent using this here. import tensor before elemwise"
)
def TensorConstant(*inputs, **kwargs):
raise Exception(
"Circular dependencies "
"prevent using this here. import tensor before elemwise"
)
class DimShuffle(ExternalCOp):
"""
Allows to reorder the dimensions of a tensor or insert or remove
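A small usage sketch, via the `dimshuffle` method that wraps this Op:

```python
import theano.tensor as tt

x = tt.dmatrix("x")           # shape (r, c)
xt = x.dimshuffle(1, 0)       # reorder: transpose to (c, r)
xb = x.dimshuffle("x", 0, 1)  # insert a broadcastable axis: (1, r, c)
v = tt.dvector("v")
col = v.dimshuffle(0, "x")    # column vector: (n, 1)
```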
@@ -138,9 +135,9 @@ class DimShuffle(ExternalCOp):
# because of importation issues related to TensorType.
return ParamsType(
input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)),
_new_order=theano.tensor.lvector,
_new_order=theano.tensor.basic.lvector,
transposition=TensorType(dtype="uint32", broadcastable=(False,)),
inplace=theano.scalar.bool,
inplace=scalar_bool,
)
@property
@@ -221,7 +218,7 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name)
def make_node(self, _input):
input = as_tensor_variable(_input)
input = theano.tensor.basic.as_tensor_variable(_input)
ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable:
if len(ib) != len(self.input_broadcastable):
@@ -294,6 +291,8 @@ class DimShuffle(ExternalCOp):
return self(*eval_points, **dict(return_list=True))
def grad(self, inp, grads):
from theano.tensor.basic import as_tensor_variable, discrete_dtypes
(x,) = inp
(gz,) = grads
gz = as_tensor_variable(gz)
@@ -304,12 +303,12 @@
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
if inp[0].dtype in theano.tensor.discrete_dtypes:
if inp[0].dtype in discrete_dtypes:
return [inp[0].zeros_like(dtype=config.floatX)]
else:
return [
DimShuffle(gz.type.broadcastable, grad_order)(
Elemwise(scalar.identity)(gz)
Elemwise(scalar_identity)(gz)
)
]
@@ -356,12 +355,12 @@ class Elemwise(OpenMPOp):
being generalized to tensors. In particular, if the calculations
for an output are done inplace on an input, the output type must
be the same as the corresponding input type (see the doc of
scalar.ScalarOp to get help about controlling the output type)
`ScalarOp` to get help about controlling the output type)
Parameters
----------
scalar_op
An instance of a subclass of scalar.ScalarOp which works uniquely
An instance of a subclass of `ScalarOp` which works uniquely
on scalars.
inplace_pattern
A dictionary that maps the index of an output to the
@@ -496,7 +495,7 @@ second dimension
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
inputs = list(map(as_tensor_variable, inputs))
inputs = list(map(theano.tensor.basic.as_tensor_variable, inputs))
out_dtypes, out_broadcastables, inputs = self.get_output_info(
DimShuffle, *inputs
)
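As a brief construction sketch (this mirrors how the casting ops in `basic.py` above are built; `inplace_pattern` maps an output index to the input it may overwrite):

```python
import theano.scalar as scal
import theano.tensor as tt
from theano.tensor.elemwise import Elemwise

add = Elemwise(scal.add)          # broadcasting tensor addition
x, y = tt.dmatrix("x"), tt.dvector("y")
z = add(x, y)                     # y is broadcast across x's rows

# Output 0 may overwrite input 0; real inplace variants also wrap the
# scalar op with transfer_type(0), as in scalar_elemwise below.
add_inplace = Elemwise(scal.add, {0: 0})
```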
@@ -525,7 +524,7 @@ second dimension
# make such that _bgrads computes only the gradients of the
# current output on the inputs ( and not all outputs)
ograds = [x.zeros_like() for x in outs]
ograds[idx] = theano.tensor.ones_like(out)
ograds[idx] = theano.tensor.basic.ones_like(out)
bgrads = self._bgrad(inputs, outs, ograds)
rop_out = None
@@ -559,32 +558,30 @@ second dimension
return [[True for output in node.outputs] for ipt in node.inputs]
def L_op(self, inputs, outs, ograds):
from theano.tensor.basic import continuous_dtypes, discrete_dtypes
from theano.tensor.basic import sum as tt_sum
# compute grad with respect to broadcasted input
# Compute grad with respect to broadcasted input
rval = self._bgrad(inputs, outs, ograds)
# TODO: make sure that zeros are clearly identifiable
# to the gradient.grad method when the outputs have
# some integer and some floating point outputs
if any(out.type.dtype not in theano.tensor.continuous_dtypes for out in outs):
# For integer output, return value may
# only be zero or undefined
# We don't bother with trying to check
# that the scalar ops correctly
# returned something that evaluates to 0,
# we just make the return
# value obviously zero so that gradient.grad
# can tell this op did
# the right thing.
if any(out.type.dtype not in continuous_dtypes for out in outs):
# For integer output, return value may only be zero or undefined
# We don't bother with trying to check that the scalar ops
# correctly returned something that evaluates to 0, we just make
# the return value obviously zero so that gradient.grad can tell
# this op did the right thing.
new_rval = []
for elem, ipt in zip(rval, inputs):
if isinstance(elem.type, (NullType, DisconnectedType)):
new_rval.append(elem)
else:
elem = ipt.zeros_like()
if str(elem.type.dtype) not in theano.tensor.continuous_dtypes:
if str(elem.type.dtype) not in continuous_dtypes:
elem = elem.astype(config.floatX)
assert str(elem.type.dtype) not in theano.tensor.discrete_dtypes
assert str(elem.type.dtype) not in discrete_dtypes
new_rval.append(elem)
return new_rval
@@ -593,9 +590,9 @@ second dimension
if isinstance(rval[i].type, (NullType, DisconnectedType)):
continue
# list of all the dimensions that are broadcastable for input[i] so
# List of all the dimensions that are broadcastable for input[i] so
# we can sum over them
# todo: only count dimensions that were effectively broadcasted
# TODO: only count dimensions that were effectively broadcasted
to_sum = [
j
for j, bcast in enumerate(ipt.type.broadcastable)
@@ -603,10 +600,8 @@ second dimension
]
if to_sum:
sr = theano.tensor.basic.sum(rval[i], axis=to_sum, keepdims=True)
sr = tt_sum(rval[i], axis=to_sum, keepdims=True)
rval[i] = sr
# close if
# close for
return rval
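The summing step has a direct NumPy analogue: a gradient arriving at the broadcast shape must be summed back over every dimension in which the input was broadcastable, keeping that dimension so the shapes still line up:

```python
import numpy as np

x = np.ones((1, 3))                      # broadcastable along axis 0
g_out = np.arange(12.0).reshape(4, 3)    # gradient at the broadcast shape
g_x = g_out.sum(axis=0, keepdims=True)   # sum over the broadcast axis
assert g_x.shape == x.shape
```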
@@ -653,7 +648,9 @@ second dimension
# the gradient contains a constant, translate it as
# an equivalent TensorType of size 1 and proper number of
# dimensions
res = theano.tensor.constant(np.asarray(r.data), dtype=r.type.dtype)
res = theano.tensor.basic.constant(
np.asarray(r.data), dtype=r.type.dtype
)
return DimShuffle((), ["x"] * nd)(res)
new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
@@ -721,9 +718,9 @@ second dimension
# when the input is complex. So add it only when the inputs are ints.
out_dtype = node.outputs[0].dtype
if (
out_dtype in theano.tensor.float_dtypes
out_dtype in theano.tensor.basic.float_dtypes
and isinstance(self.nfunc, np.ufunc)
and node.inputs[0].dtype in theano.tensor.discrete_dtypes
and node.inputs[0].dtype in theano.tensor.basic.discrete_dtypes
):
char = np.sctype2char(out_dtype)
sig = char * node.nin + "->" + char * node.nout
@@ -871,7 +868,7 @@ second dimension
# there must be some input that is not broadcastable in
# dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs):
if isinstance(i.type, theano.scalar.Scalar):
if isinstance(i.type, Scalar):
continue # we skip scalar
if not i.type.broadcastable[dim]:
# input i is not broadcastable in position dim
@@ -1032,7 +1029,7 @@ second dimension
if self.openmp:
# If we are using openmp, we need to get rid of the "goto"
# statement in sub['fail']. For now we recreate it here.
fail = theano.link.c.basic.failure_code(sub, use_goto=False)
fail = failure_code(sub, use_goto=False)
else:
fail = sub["fail"]
task_code = self.scalar_op.c_code(
@@ -1139,7 +1136,7 @@ second dimension
contig = self.scalar_op.c_code_contiguous(
node, nodename + "_scalar_contig_", _inames, onames, sub
)
except theano.graph.utils.MethodNotDefined:
except MethodNotDefined:
# Try to make one generic version, this will help the
# compiler to vectorize the code as there won't be as
# many pointers and the strides will be hard-coded.
@@ -1343,24 +1340,25 @@ class CAReduce(COp):
self.set_ufunc(scalar_op)
def set_ufunc(self, scalar_op):
# This is probably a speed up of the implementation
if isinstance(scalar_op, theano.scalar.basic.Add):
# TODO FIXME: Why would we ever do this, instead of allowing the `Op`
# itself to tell us which `ufunc` it should use?
if isinstance(scalar_op, Add):
self.ufunc = np.add
elif isinstance(scalar_op, theano.scalar.basic.Mul):
elif isinstance(scalar_op, Mul):
self.ufunc = np.multiply
elif isinstance(scalar_op, theano.scalar.basic.ScalarMaximum):
elif isinstance(scalar_op, ScalarMaximum):
self.ufunc = np.maximum
elif isinstance(scalar_op, theano.scalar.basic.ScalarMinimum):
elif isinstance(scalar_op, ScalarMinimum):
self.ufunc = np.minimum
elif isinstance(scalar_op, theano.scalar.basic.AND) and _numpy_ver >= [1, 12]:
elif isinstance(scalar_op, AND) and _numpy_ver >= [1, 12]:
# numpy.bitwise_and.identity was incorrect for versions before
# 1.12 (it was 1 instead of -1), so we skip it in that case.
# We will fall back to the "else:" case, which defines a
# ufunc without identity.
self.ufunc = np.bitwise_and
elif isinstance(scalar_op, theano.scalar.basic.OR):
elif isinstance(scalar_op, OR):
self.ufunc = np.bitwise_or
elif isinstance(scalar_op, theano.scalar.basic.XOR):
elif isinstance(scalar_op, XOR):
self.ufunc = np.bitwise_xor
else:
self.ufunc = np.frompyfunc(scalar_op.impl, 2, 1)
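The final branch builds an object-dtype ufunc from the scalar Op's Python implementation; roughly (with a hypothetical `impl` standing in for `scalar_op.impl`):

```python
import numpy as np

def impl(x, y):                    # hypothetical scalar implementation
    return x + y

ufunc = np.frompyfunc(impl, 2, 1)  # 2 inputs, 1 output, no identity
ufunc(np.arange(3), 10)            # -> array([10, 11, 12], dtype=object)

# The explicit cases above are preferred because native ufuncs are faster
# and define a usable identity for reductions, e.g.:
np.add.identity                    # 0
np.bitwise_and.identity            # -1 on NumPy >= 1.12 (see comment above)
```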
@@ -1369,6 +1367,8 @@ class CAReduce(COp):
return input_dtype
def make_node(self, input):
from theano.tensor.basic import as_tensor_variable
input = as_tensor_variable(input)
if self.axis is not None:
@@ -1492,11 +1492,13 @@
idtype = input.type.dtype_specs()[1]
odtype = output.type.dtype_specs()[1]
if hasattr(self, "acc_dtype") and self.acc_dtype is not None:
if self.acc_dtype == "float16":
raise theano.graph.utils.MethodNotDefined("no c_code for " "float16")
acc_dtype = getattr(self, "acc_dtype", None)
if acc_dtype is not None:
if acc_dtype == "float16":
raise MethodNotDefined("no c_code for float16")
acc_type = TensorType(
broadcastable=node.outputs[0].broadcastable, dtype=self.acc_dtype
broadcastable=node.outputs[0].broadcastable, dtype=acc_dtype
)
adtype = acc_type.dtype_specs()[1]
else:
@@ -1509,9 +1511,9 @@
if len(axis) == 0:
# The acc_dtype is never a downcast compared to the input dtype
# So we just need a cast to the output dtype.
var = theano.tensor.cast(input, node.outputs[0].dtype)
var = theano.tensor.basic.cast(input, node.outputs[0].dtype)
if var is input:
var = Elemwise(scalar.identity)(input)
var = Elemwise(scalar_identity)(input)
assert var.dtype == node.outputs[0].dtype
return var.owner.op._c_all(var.owner, name, inames, onames, sub)
@@ -1570,8 +1572,8 @@
if hasattr(self.scalar_op, "identity"):
identity = self.scalar_op.identity
elif self.scalar_op in [scalar.scalar_maximum, scalar.scalar_minimum]:
if self.scalar_op == scalar.scalar_maximum:
elif self.scalar_op in [scalar_maximum, scalar_minimum]:
if self.scalar_op == scalar_maximum:
scal_name = "maximum"
if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()"
@@ -1580,7 +1582,7 @@
identity = "0"
else:
identity = "NPY_MIN_" + str(input.type.dtype).upper()
if self.scalar_op == scalar.scalar_minimum:
if self.scalar_op == scalar_minimum:
scal_name = "minimum"
if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()"
@@ -1723,7 +1725,7 @@ class All(CAReduce):
nfunc_spec = ("all", 1, 1)
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis)
super().__init__(and_, axis)
def _output_dtype(self, idtype):
return "bool"
@@ -1735,9 +1737,11 @@ class All(CAReduce):
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
from theano.tensor.basic import as_tensor_variable, neq
input = as_tensor_variable(input)
if input.dtype != "bool":
input = theano.tensor.neq(input, 0)
input = neq(input, 0)
ret = super().make_node(input)
return ret
@@ -1756,7 +1760,7 @@ class Any(CAReduce):
nfunc_spec = ("any", 1, 1)
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis)
super().__init__(or_, axis)
def _output_dtype(self, idtype):
return "bool"
@@ -1768,9 +1772,11 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
from theano.tensor.basic import as_tensor_variable, neq
input = as_tensor_variable(input)
if input.dtype != "bool":
input = theano.tensor.neq(input, 0)
input = neq(input, 0)
ret = super().make_node(input)
return ret
@@ -1835,7 +1841,7 @@ class CAReduceDtype(CAReduce):
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype")
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
CAReduce.__init__(self, scalar_op, axis=axis)
super().__init__(scalar_op, axis=axis)
self.dtype = dtype
self.acc_dtype = acc_dtype
@@ -1895,14 +1901,14 @@ class CAReduceDtype(CAReduce):
complex64="complex128",
).get(idtype, idtype)
elif (
acc_dtype in theano.tensor.continuous_dtypes
and idtype in theano.tensor.discrete_dtypes
acc_dtype in theano.tensor.basic.continuous_dtypes
and idtype in theano.tensor.basic.discrete_dtypes
):
# Specifying a continuous accumulator for discrete input is OK
return acc_dtype
else:
# The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, acc_dtype)
upcasted_dtype = upcast(idtype, acc_dtype)
if acc_dtype != upcasted_dtype:
raise TypeError(
f"Cannot build {self} node with input dtype {idtype} "
@@ -1922,11 +1928,13 @@
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
input = as_tensor_variable(input)
input = theano.tensor.basic.as_tensor_variable(input)
dtype = self._output_dtype(input.dtype)
acc_dtype = self._acc_dtype(input.dtype)
assert dtype is not None
assert acc_dtype is not None
if dtype == self.dtype and acc_dtype == self.acc_dtype:
# Don't build another instance
op = self
@@ -1937,7 +1945,10 @@
op.acc_dtype = acc_dtype
assert op.acc_dtype is not None
return CAReduce.make_node(op, input)
# TODO: Why doesn't `make_node` just take these
# automatically-determined values as arguments?
return super(CAReduceDtype, op).make_node(input)
def __str__(self):
name = self.__class__.__name__
@@ -1990,9 +2001,7 @@ class Sum(CAReduceDtype):
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(
self, scalar.add, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
super().__init__(scalar_add, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
def __str__(self):
name = self.__class__.__name__
@@ -2003,9 +2012,11 @@
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
def L_op(self, inp, out, grads):
from theano.tensor.basic import as_tensor_variable
(x,) = inp
if out[0].dtype not in theano.tensor.continuous_dtypes:
if out[0].dtype not in theano.tensor.basic.continuous_dtypes:
return [x.zeros_like(dtype=config.floatX)]
(gz,) = grads
@@ -2024,7 +2035,7 @@
new_dims.append(i)
i += 1
ds_op = DimShuffle(gz.type.broadcastable, new_dims)
gx = Elemwise(scalar.second)(x, ds_op(gz))
gx = Elemwise(second)(x, ds_op(gz))
return [gx]
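Net effect: the gradient of a sum is the incoming gradient broadcast back over the summed axes, so for a plain `sum(x)` it is a tensor of ones:

```python
import numpy as np
import theano
import theano.tensor as tt

x = tt.dmatrix("x")
g = theano.grad(tt.sum(x), x)
theano.function([x], g)(np.zeros((2, 3)))  # a (2, 3) array of ones
```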
def R_op(self, inputs, eval_points):
@@ -2039,7 +2050,7 @@ class Prod(CAReduceDtype):
"""
Multiplies all the values of a tensor along the specified axis(es).
Equivalent to `CAReduce(scalar.prod, axis = axis)`, with the
Equivalent to `CAReduce(scalar.mul, axis = axis)`, with the
difference that this defines the gradient of prod wrt its tensor
input.
@@ -2049,9 +2060,7 @@
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__(
self, scalar.mul, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
super().__init__(scalar_mul, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct):
@@ -2106,13 +2115,14 @@ class Prod(CAReduceDtype):
based on the result of this count.
"""
from theano.tensor.basic import as_tensor_variable, discrete_dtypes, eq, neq
from theano.tensor.basic import sum as tt_sum
from theano.tensor.basic import switch
(prod_in,) = inp
(gz,) = grads
if (
out[0].dtype in theano.tensor.discrete_dtypes
or self.acc_dtype in theano.tensor.discrete_dtypes
):
if out[0].dtype in discrete_dtypes or self.acc_dtype in discrete_dtypes:
# There is an int conversion in the way
return [prod_in.zeros_like(dtype=config.floatX)]
@@ -2147,17 +2157,16 @@
# this handles inputs with zeros, but only certain input shapes
return [grad_case_without_zeros]
else:
T = theano.tensor
where_zeros = T.eq(prod_in, 0.0)
sum_where_zeros = T.sum(where_zeros, axis=self.axis)
groups_with_single_zero = T.eq(sum_where_zeros, 1).dimshuffle(new_dims)
where_zeros = eq(prod_in, 0.0)
sum_where_zeros = tt_sum(where_zeros, axis=self.axis)
groups_with_single_zero = eq(sum_where_zeros, 1).dimshuffle(new_dims)
# tensor with 0 everywhere except for those places where
# a 0 part of a group with a single zero was to be found
where_single_zero = groups_with_single_zero * where_zeros
# further optimization to avoid computing ProdWithoutZeros
# if the incoming gradient is 0
where_gz_not_zero = T.neq(gz, 0.0)
where_gz_not_zero = neq(gz, 0.0)
# only take ProdWithoutZeros for the groups with single zeros
# with non-null incoming gradient
where_to_take_prod_without_zeros = (
@@ -2173,12 +2182,12 @@
prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in)
prod_without_zeros = prod_without_zeros.dimshuffle(new_dims)
groups_without_zeros = T.eq(sum_where_zeros, 0).dimshuffle(new_dims)
groups_without_zeros = eq(sum_where_zeros, 0).dimshuffle(new_dims)
final_grad = T.switch(
final_grad = switch(
groups_without_zeros,
grad_case_without_zeros,
T.switch(where_single_zero, prod_without_zeros, 0.0) * gz,
switch(where_single_zero, prod_without_zeros, 0.0) * gz,
)
return [final_grad]
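A finite-difference check of the single-zero case handled above: at the zero entry, the derivative of the product is the product of the remaining entries (what `ProdWithoutZeros` computes), and at every other entry of that group it is zero:

```python
import numpy as np

x = np.array([2.0, 0.0, 5.0])
eps = 1e-6
# d(prod)/dx[1], the zero entry: ~10.0 == 2.0 * 5.0
(np.prod([2.0, eps, 5.0]) - np.prod(x)) / eps
# d(prod)/dx[0], a non-zero entry: ~0.0, since the zero kills the product
(np.prod([2.0 + eps, 0.0, 5.0]) - np.prod(x)) / eps
```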
@@ -2187,7 +2196,7 @@
return (1,)
class MulWithoutZeros(scalar.BinaryScalarOp):
class MulWithoutZeros(BinaryScalarOp):
# "identity" here is zero, as in Reduce we don't want to start
# with reducing (1, something_else): this leads to the erroneous
# case where a vector of zeros is reduced by binary reductions
@@ -2217,7 +2226,7 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name="mul_without_zeros")
mul_without_zeros = MulWithoutZeros(upcast_out, name="mul_without_zeros")
class ProdWithoutZeros(CAReduceDtype):
@@ -2225,13 +2234,13 @@ class ProdWithoutZeros(CAReduceDtype):
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(
self, mul_without_zeros, axis=axis, dtype=dtype, acc_dtype=acc_dtype
)
super().__init__(mul_without_zeros, axis=axis, dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads):
from theano.gradient import grad_not_implemented
(a,) = inp
a_grad = theano.gradient.grad_not_implemented(
a_grad = grad_not_implemented(
self,
0,
a,
@@ -2253,6 +2262,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
not take a NumPy array argument to put its result in.
"""
import theano.scalar as scalar
def construct(symbol):
nonlocal symbolname
@@ -2262,7 +2272,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}"
scalar_op = getattr(scalar, symbolname[: -len("_inplace")])
inplace_scalar_op = scalar_op.__class__(scalar.transfer_type(0))
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
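For context, `theano.tensor.basic` applies this decorator to stub functions whose names pick out the scalar Op; a rough sketch of the usage:

```python
@scalar_elemwise
def add(a, b):
    """elementwise addition"""

# `add` is now an Elemwise(scalar.add) callable; a name ending in
# "_inplace" would instead produce the in-place variant built above.
```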