提交 24b8bc28 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make as_tensor_variable's handling of constants consistent

Now, when `as_tensor_variable` is given `Constant` types, it will return the same result as a call with the underlying non-`Variable` data. Closes #98
上级 6a2259f4
...@@ -2755,9 +2755,38 @@ class ApplyDefaultTestOp(theano.Op): ...@@ -2755,9 +2755,38 @@ class ApplyDefaultTestOp(theano.Op):
return theano.Apply(self, [x], [x.type()]) return theano.Apply(self, [x], [x.type()])
def test_constant():
    """`constant` passes `TensorConstant`s through and reshapes on request."""
    # A `TensorConstant` whose type already matches is returned unchanged.
    vec_type = tensor.TensorType(dtype="int8", broadcastable=(False,))
    c = tensor.TensorConstant(vec_type, [1, 2])
    assert constant(c) is c

    # Broadcastable (length-one) dimensions can be added on the left...
    scalar_type = tensor.TensorType(dtype="int8", broadcastable=())
    data = np.array(2, dtype="int8")
    c = tensor.TensorConstant(scalar_type, data)

    c1 = constant(c, ndim=1)
    assert c1.ndim == 1
    assert np.array_equal(c1.data, np.expand_dims(data, 0))

    c2 = constant(c, ndim=2)
    assert c2.ndim == 2
    assert np.array_equal(c2.data, np.expand_dims(data, (0, 1)))

    # ...and removed again, recovering the original data.
    c0 = constant(c2, ndim=0)
    assert c2.ndim == 2 and c0.ndim == 0
    assert np.array_equal(c0.data, data)
class TestAsTensorVariable: class TestAsTensorVariable:
# Unit test for ensuring that as_tensor_variable handles Apply objects """
# correctly and removes leading broadcastable dimensions when possible. Unit test for ensuring that as_tensor_variable handles Apply objects
correctly and removes leading broadcastable dimensions when possible.
"""
def setup_method(self):
    # A fresh symbolic scalar for each test.
    self.x = tensor.scalar("x")
...@@ -2793,35 +2822,65 @@ class TestAsTensorVariable: ...@@ -2793,35 +2822,65 @@ class TestAsTensorVariable:
with pytest.raises(ValueError): with pytest.raises(ValueError):
as_tensor_variable(x, ndim=1) as_tensor_variable(x, ndim=1)
def test_bool(self):
    # Ticket #649: `as_tensor_variable` must reject the Python booleans
    # `True` and `False` outright (ndarrays of bool are handled separately).
    with pytest.raises(TypeError):
        as_tensor_variable(True)

    with pytest.raises(TypeError):
        as_tensor_variable(False)
def test_ndarray_bool(self):
    # An ndarray of `bool` is accepted and keeps its dtype.
    res = as_tensor_variable(np.array([True, False, False, True, True]))
    assert res.type.dtype == "bool"
def test_memmap(self):
    """`as_tensor_variable` should wrap a `np.memmap` without copying it."""
    import os

    inp = np.random.rand(4, 3)
    fd, fname = mkstemp()
    # Close the raw descriptor returned by `mkstemp`; `np.memmap` opens its
    # own handle on `fname`, so discarding `fd` would leak a descriptor.
    os.close(fd)
    new_inp = np.memmap(fname, dtype=inp.dtype, mode="w+", shape=inp.shape)
    new_inp[...] = inp
    res = as_tensor_variable(new_inp)
    assert isinstance(res, tensor.TensorConstant)
    # The memmap object itself is stored as the constant's data — no copy.
    assert res.data is new_inp
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_empty_dtype(self, dtype):
    # Empty sequences are converted using the configured `floatX` dtype.
    with theano.change_flags(floatX=dtype):
        assert as_tensor_variable(()).dtype == dtype
        assert as_tensor_variable([]).dtype == dtype
@pytest.mark.parametrize(
    ("x", "y"),
    [
        ([1, 2], [1, 2]),
        ([tensor.as_tensor(1), tensor.as_tensor(2)], [1, 2]),
        ([theano.scalar.constant(1), theano.scalar.constant(2)], [1, 2]),
    ],
)
def test_constant_consistency(self, x, y):
    # A list of `Constant`s converts to the same `TensorConstant` as a list
    # of the underlying raw values.
    res = as_tensor_variable(x)
    assert isinstance(res, tensor.TensorConstant)
    assert np.array_equal(res.data, y)
def test_constant_identity(self):
    # Values that are already `TensorConstant`s must be returned untouched
    # by `as_tensor_variable`, not reconstructed.
    cases = [
        (tensor.TensorType(dtype="int8", broadcastable=()), 2),
        (
            tensor.TensorType(dtype="int8", broadcastable=(False,)),
            np.array([1, 2], dtype="int8"),
        ),
    ]
    for ttype, data in cases:
        x = tensor.TensorConstant(ttype, data)
        assert as_tensor_variable(x) is x
class TestAlloc: class TestAlloc:
......
"""A `Type` and `Op` classes to work with numpy.ndarrays symbolically.""" """A `Type` and `Op` classes to work with numpy.ndarrays symbolically."""
import sys
import warnings import warnings
import numbers import numbers
import logging import logging
...@@ -113,10 +112,10 @@ def __oplist_tag(thing, tag): ...@@ -113,10 +112,10 @@ def __oplist_tag(thing, tag):
def as_tensor_variable(x, name=None, ndim=None): def as_tensor_variable(x, name=None, ndim=None):
"""Return `x`, transformed into a `TensorType`. """Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses This function is often used by `make_node` methods of `Op` subclasses to
to turn ndarrays, numbers, `Scalar` instances, `Apply` instances and turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
`TensorType` instances into valid input list elements. `TensorType` instances into valid input list elements.
Parameters Parameters
...@@ -140,6 +139,13 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -140,6 +139,13 @@ def as_tensor_variable(x, name=None, ndim=None):
If `x` cannot be converted to a TensorType Variable. If `x` cannot be converted to a TensorType Variable.
""" """
if (
isinstance(getattr(x, "type", None), TensorType)
and (name is None or x.name == name)
and (ndim is None or x.ndim == ndim)
):
return x
if hasattr(x, "_as_TensorVariable"): if hasattr(x, "_as_TensorVariable"):
return x._as_TensorVariable() # TODO: pass name and ndim arguments return x._as_TensorVariable() # TODO: pass name and ndim arguments
...@@ -147,18 +153,24 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -147,18 +153,24 @@ def as_tensor_variable(x, name=None, ndim=None):
# use Apply's default output mechanism # use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1): if (x.op.default_output is None) and (len(x.outputs) != 1):
raise ValueError( raise ValueError(
"It is ambiguous which output of a multi-output Op has" "Multi-output Op encountered. "
" to be fetched.", "Retry using only one of the outputs directly."
x,
) )
x = x.default_output() x = x.default_output()
if isinstance(x, Variable): if isinstance(x, Variable):
if isinstance(x, Constant):
return as_tensor_variable(x.data, name=name, ndim=ndim)
if isinstance(x.type, scal.Scalar): if isinstance(x.type, scal.Scalar):
x = tensor_from_scalar(x) x = tensor_from_scalar(x)
if not isinstance(x.type, TensorType): if not isinstance(x.type, TensorType):
raise AsTensorError("Variable type field must be a TensorType.", x, x.type) raise AsTensorError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
)
if ndim is None: if ndim is None:
return x return x
...@@ -171,23 +183,36 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -171,23 +183,36 @@ def as_tensor_variable(x, name=None, ndim=None):
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:]) x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
if x.ndim > ndim: if x.ndim > ndim:
raise ValueError( raise ValueError(
"TensorType could not be cast to have %i dimensions" % ndim, "Tensor of type {} could not be cast to have {} dimensions".format(
x.type, x.type, ndim
)
) )
return x return x
elif x.type.ndim < ndim: elif x.type.ndim < ndim:
return shape_padleft(x, n_ones=(ndim - x.type.ndim)) return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
return x return x
if isinstance(x, (tuple, list)) and python_any(
isinstance(xi, Variable) for xi in x elif isinstance(x, Sequence):
):
def extract_constants(i):
if isinstance(i, Variable):
if isinstance(i, Constant):
return i.data
else:
raise TypeError
else:
return i
try: try:
return stack(x) x = [extract_constants(i) for i in x]
except (TypeError, ValueError): except TypeError:
pass try:
return stack(x)
except (TypeError, ValueError):
pass
if isinstance(x, bool): elif isinstance(x, bool):
raise AsTensorError( raise AsTensorError(
"Cannot cast True or False as a tensor variable. Please use " "Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. " "np.array(True) or np.array(False) if you need these constants. "
...@@ -199,11 +224,9 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -199,11 +224,9 @@ def as_tensor_variable(x, name=None, ndim=None):
try: try:
return constant(x, name=name, ndim=ndim) return constant(x, name=name, ndim=ndim)
except TypeError: except TypeError:
try: raise AsTensorError(
str_x = str(x) "Cannot convert {} of type {} to TensorType".format(x, type(x))
except Exception: )
str_x = repr(x)
raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
# this has a different name, because _as_tensor_variable is the # this has a different name, because _as_tensor_variable is the
def constant(x, name=None, ndim=None, dtype=None):
    """Return a `TensorConstant` for `x`, adjusting `ndim` and `dtype` as needed.

    Parameters
    ----------
    x
        The data (or an existing `TensorConstant`) to wrap.
    name : str, optional
        Name for the resulting constant.
    ndim : int, optional
        Required number of dimensions.  Length-one dimensions are added or
        squeezed on the left to reach it.
    dtype : str or numpy dtype, optional
        Required dtype of the result.

    Raises
    ------
    ValueError
        If `x` could not be expanded to have `ndim` dimensions.
    TypeError
        If `x` could not be converted to a `TensorType`.

    """
    if isinstance(x, TensorConstant):
        # A constant that already satisfies the requested name/ndim/dtype is
        # returned as-is; otherwise rebuild from its underlying data below.
        if (
            (name is None or x.name == name)
            and (ndim is None or x.ndim == ndim)
            and (dtype is None or x.dtype == dtype)
        ):
            return x
        else:
            x = x.data

    x_ = scal.convert(x, dtype=dtype)

    if ndim is not None:
        if x_.ndim < ndim:
            # Pad with broadcastable (length-one) dimensions on the left.
            x_ = np.expand_dims(x_, axis=tuple(range(ndim - x_.ndim)))
        elif x_.ndim > ndim:
            try:
                # Only length-one leading dimensions can be dropped;
                # `np.squeeze` raises otherwise.
                x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim)))
            except np.AxisError as err:
                # Chain the cause so the squeeze failure isn't lost.
                raise ValueError(
                    "ndarray could not be cast to constant with %i dimensions" % ndim
                ) from err

        assert x_.ndim == ndim

    ttype = TensorType(dtype=x_.dtype, broadcastable=[s == 1 for s in x_.shape])

    try:
        return TensorConstant(ttype, x_, name=name)
    except Exception as err:
        # Chain the cause so the original construction error is visible.
        raise TypeError("Could not convert %s to TensorType" % x, type(x)) from err
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论