提交 df700c8f authored 作者: James Bergstra's avatar James Bergstra

disabling TensorType.shape pending support for partially-known type attributes

上级 3834ea89
...@@ -181,9 +181,19 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -181,9 +181,19 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
try: try:
if rtype is TensorConstant: if rtype is TensorConstant:
# put the shape into the type if 0:
# put the shape into the type
# This is disabled because if a tensor has shape, then the following fails:
# theano.lvector == as_tensor_variable([0,1]).type
# I think the solution is that we should implement something more like
# compatability instead of equality in our Type comparisons... but we're not
# there yet.
x_shape = x_.shape
else:
x_shape = None
return rtype( return rtype(
TensorType(dtype = x_.dtype, broadcastable = bcastable, shape=x_.shape), TensorType(dtype = x_.dtype, broadcastable = bcastable, shape=x_shape),
x_, name=name) x_, name=name)
else: else:
# leave the shape out of the type # leave the shape out of the type
...@@ -1096,8 +1106,8 @@ def shape(a): ...@@ -1096,8 +1106,8 @@ def shape(a):
If the shape of the expression is not known at graph-construction time, then a symbolic If the shape of the expression is not known at graph-construction time, then a symbolic
lvector will be returned, corresponding to the actual shape at graph-execution time. lvector will be returned, corresponding to the actual shape at graph-execution time.
""" """
print 'GOT A', a, a.type
va = as_tensor_variable(a) va = as_tensor_variable(a)
#print 'HERE', va, va.type
if None in va.type.shape: if None in va.type.shape:
# Some shape components are unknown at this time # Some shape components are unknown at this time
return _shape(va) return _shape(va)
...@@ -1106,7 +1116,7 @@ def shape(a): ...@@ -1106,7 +1116,7 @@ def shape(a):
# a tuple directly. This tuple is like the numpy.ndarray.shape tuple. # a tuple directly. This tuple is like the numpy.ndarray.shape tuple.
return va.type.shape return va.type.shape
pprint.assign(shape, printing.MemberPrinter('shape')) pprint.assign(_shape, printing.MemberPrinter('shape'))
class MaxAndArgmax(Op): class MaxAndArgmax(Op):
...@@ -2403,7 +2413,7 @@ def get_vector_length(v): ...@@ -2403,7 +2413,7 @@ def get_vector_length(v):
return join.vec_length(v) return join.vec_length(v)
except ValueError: except ValueError:
pass pass
if v.owner and v.owner.op == shape: if v.owner and v.owner.op == _shape:
return v.owner.inputs[0].type.ndim return v.owner.inputs[0].type.ndim
raise ValueError("length not known") raise ValueError("length not known")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论