Miscellaneous: comments, error message to point people towards using scalar

instead of using Join for scalar values
上级 5a1fd228
......@@ -95,7 +95,11 @@ def as_tensor(x, name = None):
try:
return constant(x)
except TypeError:
raise TypeError("Cannot convert %s to Tensor" % x, type(x))
try:
str_x = str(x)
except:
str_x = repr(x)
raise TypeError("Cannot convert %s to Tensor" % str_x, type(x))
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
......@@ -506,6 +510,8 @@ class _tensor_py_operators:
#TRANSPOSE
T = property(lambda self: transpose(self))
shape = property(lambda self: shape(self))
#SLICING
# def __getitem__(self, args): return Subtensor.from_idxs(self,
# args).outputs[0]
......@@ -542,13 +548,19 @@ class _tensor_py_operators:
class TensorResult(Result, _tensor_py_operators):
pass
"""Subclass to add the tensor operators to the basic `Result` class."""
class TensorConstant(Constant, _tensor_py_operators):
pass
"""Subclass to add the tensor operators to the basic `Constant` class.
To create a TensorConstant, use the `constant` function in this module.
"""
class TensorValue(Value, _tensor_py_operators):
pass
"""Subclass to add the tensor operators to the basic `Value` class.
To create a TensorValue, use the `value` function in this module.
"""
#QUESTION: why are we doing this!?
elemwise.as_tensor = as_tensor
......@@ -1428,6 +1440,8 @@ class Join(Op):
Of course, TensorResult instances don't have a shape, so this error can't be caught until
runtime. See `perform()`.
For joins involving scalar values, see @stack.
.. python::
x, y, z = tensor.matrix(), tensor.matrix(), tensor.matrix()
......@@ -1447,6 +1461,9 @@ class Join(Op):
as_tensor_args= [as_tensor(x) for x in tensors]
dtypes = [x.type.dtype for x in as_tensor_args]
if not all(targs.type.ndim for targs in as_tensor_args):
raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack');
if not all([dtypes[0] == dt for dt in dtypes[1:]]):
# Note that we could automatically find out the appropriate dtype
# able to store the concatenation of all tensors, but for now we
......@@ -1468,7 +1485,10 @@ class Join(Op):
raise ValueError('Dimensions other than the given axis must'
' match', tensors)
bcastable[:] = as_tensor_args[0].type.broadcastable
bcastable[axis] = False
try:
bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
inputs = [as_tensor(axis)] + as_tensor_args
if inputs[0].type not in int_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论