提交 eccbadb0 authored 作者: Frederic's avatar Frederic

Better error message for theano.tensor.{all,any}({True,False})

fix gh-1615
上级 2ef1058b
......@@ -1606,6 +1606,7 @@ class All(CAReduce):
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input)
......@@ -1631,6 +1632,7 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论