提交 070ae21e authored 作者: James Bergstra's avatar James Bergstra

refactored tensor.constant and tensor.value to include an ndim argument. also…

refactored tensor.constant and tensor.value to include an ndim argument. also added ndim argument to as_tensor
上级 50030aa8
......@@ -59,7 +59,7 @@ def __oplist_tag(thing, tag):
thing.__oplist_tags = tags
def as_tensor(x, name = None):
def as_tensor(x, name = None, ndim=None):
"""Return `x`, transformed into a `Tensor`
This function is often used by `make_node` methods of `Op` subclasses to
......@@ -73,6 +73,8 @@ def as_tensor(x, name = None):
to make an ndarray.
- `name`: str or None
If a new `Result` instance is created, it will be named with this string.
- `ndim`: None or integer
Return a Result with this many dimensions. Raise TypeError if it's not possible.
:Exceptions:
- `ValueError`: raised if an `Apply` with no default output is fetched
......@@ -88,12 +90,23 @@ def as_tensor(x, name = None):
x = x.outputs[0]
if isinstance(x, Result):
if isinstance(x.type, scal.Scalar):
return tensor_from_scalar(x)
x = tensor_from_scalar(x)
if not isinstance(x.type, Tensor):
raise TypeError("Result type field must be a Tensor.", x, x.type)
return x
if ndim is None:
return x
else:
if (x.type.ndim > ndim):
#TODO: strip off leading broadcastable dimensions
raise ValueError('Tensor could not be cast to have %i dimensions' % ndim, x.type)
elif (x.type.ndim < ndim):
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
return x
try:
return constant(x)
return constant(x, name=name, ndim=ndim)
except TypeError:
try:
str_x = str(x)
......@@ -105,43 +118,39 @@ def as_tensor(x, name = None):
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor = as_tensor
def constant(x, name=None):
def constant_or_value(x, rtype, name=None, ndim=None):
"""Return a symbolic `Constant` with value `x`
:Exceptions:
- `TypeError`: `x` could not be converted to a numpy.ndarray
"""
if isinstance(x, numpy.ndarray):
x_ = x
else:
x_ = numpy.asarray(x)
try:
return TensorConstant(Tensor(dtype = x_.dtype,
broadcastable = [d == 1 for d in x_.shape]), x_, name=name)
except:
raise TypeError("Could not convert %s to Tensor" % x, type(x))
- `ValueError`: `x` could not be expanded to have ndim dimensions
def value(x, name=None):
"""Return a symbolic `Value` with default value `x`
:Exceptions:
- `TypeError`: `x` could not be converted to a numpy.ndarray
"""
if isinstance(x, numpy.ndarray):
x_ = x
else:
x_ = numpy.asarray(x)
bcastable = [d == 1 for d in x_.shape]
if ndim is not None:
if len(bcastable) < ndim:
bcastable = [True] * (ndim - len(bcastable)) + bcastable
elif len(bcastable) > ndim:
#TODO: strip off dimensions of size 1
raise ValueError('ndarray could not be cast to constant with %i dimensions' % ndim)
assert len(bcastable) == ndim
try:
if name is None:
return TensorValue(Tensor(dtype = x_.dtype,
broadcastable = [d == 1 for d in x_.shape]), x_)
else:
return TensorValue(Tensor(dtype = x_.dtype,
broadcastable = [d == 1 for d in x_.shape]), x_, name=name)
return rtype(Tensor(dtype = x_.dtype, broadcastable = bcastable), x_, name=name)
except:
raise TypeError("Could not convert %s to Tensor" % x, type(x))
def constant(x, name=None, ndim=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim)
def value(x, name=None, ndim=None):
return constant_or_value(x, rtype=TensorValue, name=name, ndim=ndim)
class Tensor(Type):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论