提交 24b8bc28 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make as_tensor_variable's handling of constants consistent

Now, when `as_tensor_variable` is given `Constant` types, it will return the same result as a call with the underlying non-`Variable` data. Closes #98
上级 6a2259f4
......@@ -2755,9 +2755,38 @@ class ApplyDefaultTestOp(theano.Op):
return theano.Apply(self, [x], [x.type()])
def test_constant():
int8_type = tensor.TensorType(dtype="int8", broadcastable=(False,))
# Make sure we return a `TensorConstant` unchanged
x = tensor.TensorConstant(int8_type, [1, 2])
y = constant(x)
assert y is x
# Make sure we can add and remove broadcastable dimensions
int8_type = tensor.TensorType(dtype="int8", broadcastable=())
x_data = np.array(2, dtype="int8")
x = tensor.TensorConstant(int8_type, x_data)
y = constant(x, ndim=1)
assert y.ndim == 1
assert np.array_equal(y.data, np.expand_dims(x_data, 0))
y = constant(x, ndim=2)
assert y.ndim == 2
assert np.array_equal(y.data, np.expand_dims(x_data, (0, 1)))
z = constant(y, ndim=0)
assert y.ndim == 2 and z.ndim == 0
assert np.array_equal(z.data, x_data)
class TestAsTensorVariable:
# Unit test for ensuring that as_tensor_variable handles Apply objects
# correctly and removes leading broadcastable dimensions when possible.
"""
Unit test for ensuring that as_tensor_variable handles Apply objects
correctly and removes leading broadcastable dimensions when possible.
"""
def setup_method(self):
self.x = tensor.scalar("x")
......@@ -2793,35 +2822,65 @@ class TestAsTensorVariable:
with pytest.raises(ValueError):
as_tensor_variable(x, ndim=1)
# We test that ticket #649 stay fixed.
# We should not allow as_tensor_variable to accept True or False
# But it should upcast an ndarray of bool to uint8
def test_bool(self):
# We should not allow `as_tensor_variable` to accept `True` or `False`,
# but it should up-cast an `ndarray` of `bool` to uint8
with pytest.raises(TypeError):
as_tensor_variable(True)
with pytest.raises(TypeError):
as_tensor_variable(False)
def test_ndarray_bool(self):
ten = as_tensor_variable(np.array([True, False, False, True, True]))
assert ten.type.dtype == "bool"
def test_memmap(self):
inp = np.random.rand(4, 3)
f, fname = mkstemp()
_, fname = mkstemp()
new_inp = np.memmap(fname, dtype=inp.dtype, mode="w+", shape=inp.shape)
new_inp[...] = inp
as_tensor_variable(new_inp)
res = as_tensor_variable(new_inp)
assert isinstance(res, tensor.TensorConstant)
assert res.data is new_inp
def test_empty_dtype(self):
old = theano.config.floatX
for dtype in ["float16", "float32", "float64"]:
try:
theano.config.floatX = dtype
assert theano.tensor.as_tensor_variable(()).dtype == dtype
assert theano.tensor.as_tensor_variable([]).dtype == dtype
finally:
theano.config.floatX = old
@pytest.mark.parametrize(
"dtype",
[
"float16",
"float32",
"float64",
],
)
def test_empty_dtype(self, dtype):
with theano.change_flags(floatX=dtype):
assert as_tensor_variable(()).dtype == dtype
assert as_tensor_variable([]).dtype == dtype
@pytest.mark.parametrize(
("x", "y"),
[
([1, 2], [1, 2]),
([tensor.as_tensor(1), tensor.as_tensor(2)], [1, 2]),
([theano.scalar.constant(1), theano.scalar.constant(2)], [1, 2]),
],
)
def test_constant_consistency(self, x, y):
a = as_tensor_variable(x)
assert isinstance(a, tensor.TensorConstant)
assert np.array_equal(a.data, y)
def test_constant_identity(self):
# Values that are already `TensorType`s shouldn't be recreated by
# `as_tensor_variable`
x_scalar = tensor.TensorConstant(
tensor.TensorType(dtype="int8", broadcastable=()), 2
)
a_scalar = as_tensor_variable(x_scalar)
assert x_scalar is a_scalar
x_vector = tensor.TensorConstant(
tensor.TensorType(dtype="int8", broadcastable=(False,)),
np.array([1, 2], dtype="int8"),
)
a_vector = as_tensor_variable(x_vector)
assert x_vector is a_vector
class TestAlloc:
......
"""A `Type` and `Op` classes to work with numpy.ndarrays symbolically."""
import sys
import warnings
import numbers
import logging
......@@ -113,10 +112,10 @@ def __oplist_tag(thing, tag):
def as_tensor_variable(x, name=None, ndim=None):
"""Return `x`, transformed into a `TensorType`.
"""Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses
to turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
This function is often used by `make_node` methods of `Op` subclasses to
turn ndarrays, numbers, `Scalar` instances, `Apply` instances and
`TensorType` instances into valid input list elements.
Parameters
......@@ -140,6 +139,13 @@ def as_tensor_variable(x, name=None, ndim=None):
If `x` cannot be converted to a TensorType Variable.
"""
if (
isinstance(getattr(x, "type", None), TensorType)
and (name is None or x.name == name)
and (ndim is None or x.ndim == ndim)
):
return x
if hasattr(x, "_as_TensorVariable"):
return x._as_TensorVariable() # TODO: pass name and ndim arguments
......@@ -147,18 +153,24 @@ def as_tensor_variable(x, name=None, ndim=None):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise ValueError(
"It is ambiguous which output of a multi-output Op has"
" to be fetched.",
x,
"Multi-output Op encountered. "
"Retry using only one of the outputs directly."
)
x = x.default_output()
if isinstance(x, Variable):
if isinstance(x, Constant):
return as_tensor_variable(x.data, name=name, ndim=ndim)
if isinstance(x.type, scal.Scalar):
x = tensor_from_scalar(x)
if not isinstance(x.type, TensorType):
raise AsTensorError("Variable type field must be a TensorType.", x, x.type)
raise AsTensorError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
)
if ndim is None:
return x
......@@ -171,23 +183,36 @@ def as_tensor_variable(x, name=None, ndim=None):
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
"TensorType could not be cast to have %i dimensions" % ndim,
x.type,
"Tensor of type {} could not be cast to have {} dimensions".format(
x.type, ndim
)
)
return x
elif x.type.ndim < ndim:
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
return x
if isinstance(x, (tuple, list)) and python_any(
isinstance(xi, Variable) for xi in x
):
elif isinstance(x, Sequence):
def extract_constants(i):
if isinstance(i, Variable):
if isinstance(i, Constant):
return i.data
else:
raise TypeError
else:
return i
try:
return stack(x)
except (TypeError, ValueError):
pass
x = [extract_constants(i) for i in x]
except TypeError:
try:
return stack(x)
except (TypeError, ValueError):
pass
if isinstance(x, bool):
elif isinstance(x, bool):
raise AsTensorError(
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
......@@ -199,11 +224,9 @@ def as_tensor_variable(x, name=None, ndim=None):
try:
return constant(x, name=name, ndim=ndim)
except TypeError:
try:
str_x = str(x)
except Exception:
str_x = repr(x)
raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
raise AsTensorError(
"Cannot convert {} of type {} to TensorType".format(x, type(x))
)
# this has a different name, because _as_tensor_variable is the
......@@ -226,21 +249,34 @@ def constant(x, name=None, ndim=None, dtype=None):
`x` could not be expanded to have ndim dimensions.
"""
if isinstance(x, TensorConstant):
if (
(name is None or x.name == name)
and (ndim is None or x.ndim == ndim)
and (dtype is None or x.dtype == dtype)
):
return x
else:
x = x.data
x_ = scal.convert(x, dtype=dtype)
bcastable = [d == 1 for d in x_.shape]
if ndim is not None:
if len(bcastable) < ndim:
bcastable = [True] * (ndim - len(bcastable)) + bcastable
elif len(bcastable) > ndim:
# TODO: strip off dimensions of size 1
raise ValueError(
"ndarray could not be cast to constant with %i dimensions" % ndim
)
assert len(bcastable) == ndim
if x_.ndim < ndim:
x_ = np.expand_dims(x_, axis=tuple(range(ndim - x_.ndim)))
elif x_.ndim > ndim:
try:
x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim)))
except np.AxisError:
raise ValueError(
"ndarray could not be cast to constant with %i dimensions" % ndim
)
assert x_.ndim == ndim
ttype = TensorType(dtype=x_.dtype, broadcastable=[s == 1 for s in x_.shape])
try:
ttype = TensorType(dtype=x_.dtype, broadcastable=bcastable)
return TensorConstant(ttype, x_, name=name)
except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论