提交 ac97e23a authored 作者: Frederic Bastien's avatar Frederic Bastien

Make TensorConstant.equals() work with python int and float

上级 23ccd2ac
......@@ -908,8 +908,9 @@ class TensorConstant(_tensor_py_operators, Constant):
return TensorConstantSignature((self.type, self.data))
def equals(self, other):
# Override Contant.equals to allow to compare with numpy.ndarray
if isinstance(other, numpy.ndarray):
# Override Contant.equals to allow to compare with
# numpy.ndarray, and python type.
if isinstance(other, (numpy.ndarray, int, float)):
# Make a TensorConstant to be able to compare
other = theano.tensor.basic.constant(other)
return (isinstance(other, TensorConstant) and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论