Miscellaneous: comments, error message to point people towards using scalar

instead of using Join for scalar values
上级 5a1fd228
...@@ -95,7 +95,11 @@ def as_tensor(x, name = None): ...@@ -95,7 +95,11 @@ def as_tensor(x, name = None):
try: try:
return constant(x) return constant(x)
except TypeError: except TypeError:
raise TypeError("Cannot convert %s to Tensor" % x, type(x)) try:
str_x = str(x)
except:
str_x = repr(x)
raise TypeError("Cannot convert %s to Tensor" % str_x, type(x))
# this has a different name, because _as_tensor is the function which ops use # this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor. # to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
...@@ -506,6 +510,8 @@ class _tensor_py_operators: ...@@ -506,6 +510,8 @@ class _tensor_py_operators:
#TRANSPOSE #TRANSPOSE
T = property(lambda self: transpose(self)) T = property(lambda self: transpose(self))
shape = property(lambda self: shape(self))
#SLICING #SLICING
# def __getitem__(self, args): return Subtensor.from_idxs(self, # def __getitem__(self, args): return Subtensor.from_idxs(self,
# args).outputs[0] # args).outputs[0]
...@@ -542,13 +548,19 @@ class _tensor_py_operators: ...@@ -542,13 +548,19 @@ class _tensor_py_operators:
class TensorResult(Result, _tensor_py_operators): class TensorResult(Result, _tensor_py_operators):
pass """Subclass to add the tensor operators to the basic `Result` class."""
class TensorConstant(Constant, _tensor_py_operators): class TensorConstant(Constant, _tensor_py_operators):
pass """Subclass to add the tensor operators to the basic `Constant` class.
To create a TensorConstant, use the `constant` function in this module.
"""
class TensorValue(Value, _tensor_py_operators): class TensorValue(Value, _tensor_py_operators):
pass """Subclass to add the tensor operators to the basic `Value` class.
To create a TensorValue, use the `value` function in this module.
"""
#QUESTION: why are we doing this!? #QUESTION: why are we doing this!?
elemwise.as_tensor = as_tensor elemwise.as_tensor = as_tensor
...@@ -1428,6 +1440,8 @@ class Join(Op): ...@@ -1428,6 +1440,8 @@ class Join(Op):
Of course, TensorResult instances don't have a shape, so this error can't be caught until Of course, TensorResult instances don't have a shape, so this error can't be caught until
runtime. See `perform()`. runtime. See `perform()`.
For joins involving scalar values, see @stack.
.. python:: .. python::
x, y, z = tensor.matrix(), tensor.matrix(), tensor.matrix() x, y, z = tensor.matrix(), tensor.matrix(), tensor.matrix()
...@@ -1447,6 +1461,9 @@ class Join(Op): ...@@ -1447,6 +1461,9 @@ class Join(Op):
as_tensor_args= [as_tensor(x) for x in tensors] as_tensor_args= [as_tensor(x) for x in tensors]
dtypes = [x.type.dtype for x in as_tensor_args] dtypes = [x.type.dtype for x in as_tensor_args]
if not all(targs.type.ndim for targs in as_tensor_args):
raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack');
if not all([dtypes[0] == dt for dt in dtypes[1:]]): if not all([dtypes[0] == dt for dt in dtypes[1:]]):
# Note that we could automatically find out the appropriate dtype # Note that we could automatically find out the appropriate dtype
# able to store the concatenation of all tensors, but for now we # able to store the concatenation of all tensors, but for now we
...@@ -1468,7 +1485,10 @@ class Join(Op): ...@@ -1468,7 +1485,10 @@ class Join(Op):
raise ValueError('Dimensions other than the given axis must' raise ValueError('Dimensions other than the given axis must'
' match', tensors) ' match', tensors)
bcastable[:] = as_tensor_args[0].type.broadcastable bcastable[:] = as_tensor_args[0].type.broadcastable
try:
bcastable[axis] = False bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
inputs = [as_tensor(axis)] + as_tensor_args inputs = [as_tensor(axis)] + as_tensor_args
if inputs[0].type not in int_types: if inputs[0].type not in int_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论