提交 8ca1bf09 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

generic fix to #649: made as_tensor_variable(boolean) raise a TypeError and made…

generic fix to #649: made as_tensor_variable(boolean) raise a TypeError and made some needed changes for tests to pass
上级 8a40497a
......@@ -138,6 +138,11 @@ def as_tensor_variable(x, name = None, ndim=None):
except (TypeError, ValueError):
pass
if isinstance(x, bool):
raise TypeError("Cannot cast True or False as a tensor variable. Please use 1 or 0. "
"This error might be caused by using the == operator on Variables. "
"v == w does not do what you think it does, use theano.tensor.eq(v, w) instead.")
try:
return constant(x, name=name, ndim=ndim)
except TypeError:
......@@ -3850,7 +3855,10 @@ class PermuteRowElements(Op):
def make_node(self, x, y, inverse):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
inverse = as_tensor_variable(inverse)
if inverse: # as_tensor_variable does not accept booleans
inverse = as_tensor_variable(1)
else:
inverse = as_tensor_variable(0)
# y should contain integers
assert y.type.dtype.startswith('int') or y.type.dtype.startswith('uint')
......
......@@ -1141,7 +1141,7 @@ def local_subtensor_merge(node):
u_start = u.owner.op.idx_list[0].start
if len(u.owner.inputs) == 1 and isinstance(u_start, int):
start0 = u_start
start0 = T.as_tensor_variable(u_start)
elif (len(u.owner.inputs) == 2 and
isinstance (u_start, scalar.basic.Scalar)):
start0 = T.tensor_from_scalar(u.owner.inputs[1])
......@@ -1178,7 +1178,7 @@ def local_subtensor_merge(node):
):
idx = node.op.idx_list[0]
if len(node.inputs) == 1 and isinstance(idx, int):
pass
idx = T.as_tensor_variable(idx)
elif (len(node.inputs) == 2 and
isinstance (idx, scalar.basic.Scalar)):
idx = T.tensor_from_scalar(node.inputs[1])
......@@ -1202,7 +1202,7 @@ def local_subtensor_merge(node):
slice_idx = node.op.idx_list[0]
idx = slice_idx.stop
if len(node.inputs) == 1 and isinstance(idx, int):
pass
idx = T.as_tensor_variable(idx)
elif (len(node.inputs) == 2 and
isinstance (idx, scalar.basic.Scalar)):
idx = T.tensor_from_scalar(node.inputs[1])
......@@ -1227,8 +1227,12 @@ def local_subtensor_merge(node):
idx2 = node.op.idx_list[0].stop
if isinstance(idx1, scalar.basic.Scalar):
idx1 = T.tensor_from_scalar(u.owner.inputs[1])
elif isinstance(idx1, int):
idx1 = T.as_tensor_variable(idx1)
if isinstance(idx2, scalar.basic.Scalar):
idx2 = T.tensor_from_scalar(node.inputs[1])
elif isinstance(idx2, int):
idx2 = T.as_tensor_variable(idx1)
# The maximum is needed to don't have shape[0] - idx1 < 0
idx2_neg = T.maximum(u.owner.inputs[0].shape[0]+idx1, 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论