Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
19358099
提交
19358099
authored
1月 09, 2024
作者:
Ben Mares
提交者:
Ricardo Vieira
6月 06, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix E721: do not compare types, for exact checks use is / is not
上级
4b6a4440
隐藏空白字符变更
内嵌
并排
正在显示
27 个修改的文件
包含
41 行增加
和
41 行删除
+41
-41
debugmode.py
pytensor/compile/debugmode.py
+2
-2
ops.py
pytensor/compile/ops.py
+1
-1
basic.py
pytensor/graph/basic.py
+1
-1
null_type.py
pytensor/graph/null_type.py
+1
-1
unify.py
pytensor/graph/rewriting/unify.py
+2
-2
utils.py
pytensor/graph/utils.py
+1
-1
ifelse.py
pytensor/ifelse.py
+1
-1
params_type.py
pytensor/link/c/params_type.py
+2
-2
type.py
pytensor/link/c/type.py
+1
-1
raise_op.py
pytensor/raise_op.py
+2
-2
basic.py
pytensor/scalar/basic.py
+3
-3
math.py
pytensor/scalar/math.py
+5
-5
op.py
pytensor/scan/op.py
+1
-1
basic.py
pytensor/sparse/basic.py
+1
-1
type.py
pytensor/tensor/random/type.py
+1
-1
math.py
pytensor/tensor/rewriting/math.py
+1
-1
type.py
pytensor/tensor/type.py
+2
-2
type_other.py
pytensor/tensor/type_other.py
+1
-1
variable.py
pytensor/tensor/variable.py
+3
-3
type.py
pytensor/typed_list/type.py
+1
-1
test_unify.py
tests/graph/rewriting/test_unify.py
+1
-1
test_fg.py
tests/graph/test_fg.py
+2
-2
test_op.py
tests/graph/test_op.py
+1
-1
test_basic.py
tests/link/c/test_basic.py
+1
-1
test_basic.py
tests/sparse/test_basic.py
+1
-1
test_basic.py
tests/tensor/test_basic.py
+1
-1
test_subtensor.py
tests/tensor/test_subtensor.py
+1
-1
没有找到文件。
pytensor/compile/debugmode.py
浏览文件 @
19358099
...
@@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a):
...
@@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a):
else
:
else
:
rval
=
copy
.
deepcopy
(
a
)
rval
=
copy
.
deepcopy
(
a
)
assert
type
(
rval
)
==
type
(
a
),
(
type
(
rval
),
type
(
a
))
assert
type
(
rval
)
is
type
(
a
),
(
type
(
rval
),
type
(
a
))
if
isinstance
(
rval
,
np
.
ndarray
):
if
isinstance
(
rval
,
np
.
ndarray
):
assert
rval
.
dtype
==
a
.
dtype
assert
rval
.
dtype
==
a
.
dtype
...
@@ -1154,7 +1154,7 @@ class _FunctionGraphEvent:
...
@@ -1154,7 +1154,7 @@ class _FunctionGraphEvent:
return
str
(
self
.
__dict__
)
return
str
(
self
.
__dict__
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
rval
=
type
(
self
)
==
type
(
other
)
rval
=
type
(
self
)
is
type
(
other
)
if
rval
:
if
rval
:
# nodes are not compared because this comparison is
# nodes are not compared because this comparison is
# supposed to be true for corresponding events that happen
# supposed to be true for corresponding events that happen
...
...
pytensor/compile/ops.py
浏览文件 @
19358099
...
@@ -246,7 +246,7 @@ class FromFunctionOp(Op):
...
@@ -246,7 +246,7 @@ class FromFunctionOp(Op):
self
.
infer_shape
=
self
.
_infer_shape
self
.
infer_shape
=
self
.
_infer_shape
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
__fn
==
other
.
__fn
return
type
(
self
)
is
type
(
other
)
and
self
.
__fn
==
other
.
__fn
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
__fn
)
return
hash
(
type
(
self
))
^
hash
(
self
.
__fn
)
...
...
pytensor/graph/basic.py
浏览文件 @
19358099
...
@@ -748,7 +748,7 @@ class NominalVariable(AtomicVariable[_TypeType]):
...
@@ -748,7 +748,7 @@ class NominalVariable(AtomicVariable[_TypeType]):
return
True
return
True
return
(
return
(
type
(
self
)
==
type
(
other
)
type
(
self
)
is
type
(
other
)
and
self
.
id
==
other
.
id
and
self
.
id
==
other
.
id
and
self
.
type
==
other
.
type
and
self
.
type
==
other
.
type
)
)
...
...
pytensor/graph/null_type.py
浏览文件 @
19358099
...
@@ -33,7 +33,7 @@ class NullType(Type):
...
@@ -33,7 +33,7 @@ class NullType(Type):
raise
ValueError
(
"NullType has no values to compare"
)
raise
ValueError
(
"NullType has no values to compare"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
...
pytensor/graph/rewriting/unify.py
浏览文件 @
19358099
...
@@ -57,8 +57,8 @@ class ConstrainedVar(Var):
...
@@ -57,8 +57,8 @@ class ConstrainedVar(Var):
return
obj
return
obj
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
==
type
(
other
):
if
type
(
self
)
is
type
(
other
):
return
self
.
token
==
other
.
token
and
self
.
constraint
==
other
.
constraint
return
self
.
token
is
other
.
token
and
self
.
constraint
==
other
.
constraint
return
NotImplemented
return
NotImplemented
def
__hash__
(
self
):
def
__hash__
(
self
):
...
...
pytensor/graph/utils.py
浏览文件 @
19358099
...
@@ -229,7 +229,7 @@ class MetaType(ABCMeta):
...
@@ -229,7 +229,7 @@ class MetaType(ABCMeta):
if
"__eq__"
not
in
dct
:
if
"__eq__"
not
in
dct
:
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
tuple
(
return
type
(
self
)
is
type
(
other
)
and
tuple
(
getattr
(
self
,
a
)
for
a
in
props
getattr
(
self
,
a
)
for
a
in
props
)
==
tuple
(
getattr
(
other
,
a
)
for
a
in
props
)
)
==
tuple
(
getattr
(
other
,
a
)
for
a
in
props
)
...
...
pytensor/ifelse.py
浏览文件 @
19358099
...
@@ -78,7 +78,7 @@ class IfElse(_NoPythonOp):
...
@@ -78,7 +78,7 @@ class IfElse(_NoPythonOp):
self
.
name
=
name
self
.
name
=
name
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
):
if
type
(
self
)
is
not
type
(
other
):
return
False
return
False
if
self
.
as_view
!=
other
.
as_view
:
if
self
.
as_view
!=
other
.
as_view
:
return
False
return
False
...
...
pytensor/link/c/params_type.py
浏览文件 @
19358099
...
@@ -301,7 +301,7 @@ class Params(dict):
...
@@ -301,7 +301,7 @@ class Params(dict):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
return
(
type
(
self
)
==
type
(
other
)
type
(
self
)
is
type
(
other
)
and
self
.
__params_type__
==
other
.
__params_type__
and
self
.
__params_type__
==
other
.
__params_type__
and
all
(
and
all
(
# NB: Params object should have been already filtered.
# NB: Params object should have been already filtered.
...
@@ -435,7 +435,7 @@ class ParamsType(CType):
...
@@ -435,7 +435,7 @@ class ParamsType(CType):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
return
(
type
(
self
)
==
type
(
other
)
type
(
self
)
is
type
(
other
)
and
self
.
fields
==
other
.
fields
and
self
.
fields
==
other
.
fields
and
self
.
types
==
other
.
types
and
self
.
types
==
other
.
types
)
)
...
...
pytensor/link/c/type.py
浏览文件 @
19358099
...
@@ -515,7 +515,7 @@ class EnumType(CType, dict):
...
@@ -515,7 +515,7 @@ class EnumType(CType, dict):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
return
(
type
(
self
)
==
type
(
other
)
type
(
self
)
is
type
(
other
)
and
self
.
ctype
==
other
.
ctype
and
self
.
ctype
==
other
.
ctype
and
len
(
self
)
==
len
(
other
)
and
len
(
self
)
==
len
(
other
)
and
len
(
self
.
aliases
)
==
len
(
other
.
aliases
)
and
len
(
self
.
aliases
)
==
len
(
other
.
aliases
)
...
...
pytensor/raise_op.py
浏览文件 @
19358099
...
@@ -16,7 +16,7 @@ from pytensor.tensor.type import DenseTensorType
...
@@ -16,7 +16,7 @@ from pytensor.tensor.type import DenseTensorType
class
ExceptionType
(
Generic
):
class
ExceptionType
(
Generic
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
@@ -51,7 +51,7 @@ class CheckAndRaise(COp):
...
@@ -51,7 +51,7 @@ class CheckAndRaise(COp):
return
f
"CheckAndRaise{{{self.exc_type}({self.msg})}}"
return
f
"CheckAndRaise{{{self.exc_type}({self.msg})}}"
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
):
if
type
(
self
)
is
not
type
(
other
):
return
False
return
False
if
self
.
msg
==
other
.
msg
and
self
.
exc_type
==
other
.
exc_type
:
if
self
.
msg
==
other
.
msg
and
self
.
exc_type
==
other
.
exc_type
:
...
...
pytensor/scalar/basic.py
浏览文件 @
19358099
...
@@ -1074,7 +1074,7 @@ class unary_out_lookup(MetaObject):
...
@@ -1074,7 +1074,7 @@ class unary_out_lookup(MetaObject):
return
[
rval
]
return
[
rval
]
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
tbl
==
other
.
tbl
return
type
(
self
)
is
type
(
other
)
and
self
.
tbl
==
other
.
tbl
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
# ignore hash of table
return
hash
(
type
(
self
))
# ignore hash of table
...
@@ -1160,7 +1160,7 @@ class ScalarOp(COp):
...
@@ -1160,7 +1160,7 @@ class ScalarOp(COp):
return
self
.
grad
(
inputs
,
output_gradients
)
return
self
.
grad
(
inputs
,
output_gradients
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
test
=
type
(
self
)
==
type
(
other
)
and
getattr
(
test
=
type
(
self
)
is
type
(
other
)
and
getattr
(
self
,
"output_types_preference"
,
None
self
,
"output_types_preference"
,
None
)
==
getattr
(
other
,
"output_types_preference"
,
None
)
)
==
getattr
(
other
,
"output_types_preference"
,
None
)
return
test
return
test
...
@@ -4133,7 +4133,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
...
@@ -4133,7 +4133,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
if
self
is
other
:
if
self
is
other
:
return
True
return
True
if
(
if
(
type
(
self
)
!=
type
(
other
)
type
(
self
)
is
not
type
(
other
)
or
self
.
nin
!=
other
.
nin
or
self
.
nin
!=
other
.
nin
or
self
.
nout
!=
other
.
nout
or
self
.
nout
!=
other
.
nout
):
):
...
...
pytensor/scalar/math.py
浏览文件 @
19358099
...
@@ -626,7 +626,7 @@ class Chi2SF(BinaryScalarOp):
...
@@ -626,7 +626,7 @@ class Chi2SF(BinaryScalarOp):
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
@@ -679,7 +679,7 @@ class GammaInc(BinaryScalarOp):
...
@@ -679,7 +679,7 @@ class GammaInc(BinaryScalarOp):
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
@@ -732,7 +732,7 @@ class GammaIncC(BinaryScalarOp):
...
@@ -732,7 +732,7 @@ class GammaIncC(BinaryScalarOp):
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
@@ -1045,7 +1045,7 @@ class GammaU(BinaryScalarOp):
...
@@ -1045,7 +1045,7 @@ class GammaU(BinaryScalarOp):
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
@@ -1083,7 +1083,7 @@ class GammaL(BinaryScalarOp):
...
@@ -1083,7 +1083,7 @@ class GammaL(BinaryScalarOp):
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
raise
NotImplementedError
(
"only floatingpoint is implemented"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
...
pytensor/scan/op.py
浏览文件 @
19358099
...
@@ -1246,7 +1246,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1246,7 +1246,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return
apply_node
return
apply_node
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
):
if
type
(
self
)
is
not
type
(
other
):
return
False
return
False
if
self
.
info
!=
other
.
info
:
if
self
.
info
!=
other
.
info
:
...
...
pytensor/sparse/basic.py
浏览文件 @
19358099
...
@@ -462,7 +462,7 @@ class SparseConstantSignature(tuple):
...
@@ -462,7 +462,7 @@ class SparseConstantSignature(tuple):
return
(
return
(
a
==
x
a
==
x
and
(
b
.
dtype
==
y
.
dtype
)
and
(
b
.
dtype
==
y
.
dtype
)
and
(
type
(
b
)
==
type
(
y
))
and
(
type
(
b
)
is
type
(
y
))
and
(
b
.
shape
==
y
.
shape
)
and
(
b
.
shape
==
y
.
shape
)
and
(
abs
(
b
-
y
)
.
sum
()
<
1e-6
*
b
.
nnz
)
and
(
abs
(
b
-
y
)
.
sum
()
<
1e-6
*
b
.
nnz
)
)
)
...
...
pytensor/tensor/random/type.py
浏览文件 @
19358099
...
@@ -107,7 +107,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
...
@@ -107,7 +107,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
return
_eq
(
sa
,
sb
)
return
_eq
(
sa
,
sb
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
19358099
...
@@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node):
...
@@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node):
ii
+=
1
ii
+=
1
new_reduced
=
reduced
.
dimshuffle
(
*
pattern
)
new_reduced
=
reduced
.
dimshuffle
(
*
pattern
)
if
new_axis
:
if
new_axis
:
if
type
(
node
.
op
)
==
CAReduce
:
if
type
(
node
.
op
)
is
CAReduce
:
# This case handles `CAReduce` instances
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# scalar `Op`-specific subclasses
...
...
pytensor/tensor/type.py
浏览文件 @
19358099
...
@@ -370,7 +370,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -370,7 +370,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return
values_eq_approx
(
a
,
b
,
allow_remove_inf
,
allow_remove_nan
,
rtol
,
atol
)
return
values_eq_approx
(
a
,
b
,
allow_remove_inf
,
allow_remove_nan
,
rtol
,
atol
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
):
if
type
(
self
)
is
not
type
(
other
):
return
NotImplemented
return
NotImplemented
return
other
.
dtype
==
self
.
dtype
and
other
.
shape
==
self
.
shape
return
other
.
dtype
==
self
.
dtype
and
other
.
shape
==
self
.
shape
...
@@ -624,7 +624,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -624,7 +624,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
class
DenseTypeMeta
(
MetaType
):
class
DenseTypeMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorType
or
isinstance
(
o
,
DenseTypeMeta
):
if
type
(
o
)
is
TensorType
or
isinstance
(
o
,
DenseTypeMeta
):
return
True
return
True
return
False
return
False
...
...
pytensor/tensor/type_other.py
浏览文件 @
19358099
...
@@ -64,7 +64,7 @@ class SliceType(Type[slice]):
...
@@ -64,7 +64,7 @@ class SliceType(Type[slice]):
return
"slice"
return
"slice"
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
...
pytensor/tensor/variable.py
浏览文件 @
19358099
...
@@ -945,7 +945,7 @@ class TensorConstantSignature(tuple):
...
@@ -945,7 +945,7 @@ class TensorConstantSignature(tuple):
"""
"""
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
type
(
self
)
!=
type
(
other
):
if
type
(
self
)
is
not
type
(
other
):
return
False
return
False
try
:
try
:
(
t0
,
d0
),
(
t1
,
d1
)
=
self
,
other
(
t0
,
d0
),
(
t1
,
d1
)
=
self
,
other
...
@@ -1105,7 +1105,7 @@ TensorType.constant_type = TensorConstant
...
@@ -1105,7 +1105,7 @@ TensorType.constant_type = TensorConstant
class
DenseVariableMeta
(
MetaType
):
class
DenseVariableMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorVariable
or
isinstance
(
o
,
DenseVariableMeta
):
if
type
(
o
)
is
TensorVariable
or
isinstance
(
o
,
DenseVariableMeta
):
return
True
return
True
return
False
return
False
...
@@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
...
@@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
class
DenseConstantMeta
(
MetaType
):
class
DenseConstantMeta
(
MetaType
):
def
__instancecheck__
(
self
,
o
):
def
__instancecheck__
(
self
,
o
):
if
type
(
o
)
==
TensorConstant
or
isinstance
(
o
,
DenseConstantMeta
):
if
type
(
o
)
is
TensorConstant
or
isinstance
(
o
,
DenseConstantMeta
):
return
True
return
True
return
False
return
False
...
...
pytensor/typed_list/type.py
浏览文件 @
19358099
...
@@ -55,7 +55,7 @@ class TypedListType(CType):
...
@@ -55,7 +55,7 @@ class TypedListType(CType):
Two lists are equal if they contain the same type.
Two lists are equal if they contain the same type.
"""
"""
return
type
(
self
)
==
type
(
other
)
and
self
.
ttype
==
other
.
ttype
return
type
(
self
)
is
type
(
other
)
and
self
.
ttype
==
other
.
ttype
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
((
type
(
self
),
self
.
ttype
))
return
hash
((
type
(
self
),
self
.
ttype
))
...
...
tests/graph/rewriting/test_unify.py
浏览文件 @
19358099
...
@@ -42,7 +42,7 @@ class CustomOpNoPropsNoEq(Op):
...
@@ -42,7 +42,7 @@ class CustomOpNoPropsNoEq(Op):
class
CustomOpNoProps
(
CustomOpNoPropsNoEq
):
class
CustomOpNoProps
(
CustomOpNoPropsNoEq
):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
a
==
other
.
a
return
type
(
self
)
is
type
(
other
)
and
self
.
a
==
other
.
a
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
((
type
(
self
),
self
.
a
))
return
hash
((
type
(
self
),
self
.
a
))
...
...
tests/graph/test_fg.py
浏览文件 @
19358099
...
@@ -31,8 +31,8 @@ class TestFunctionGraph:
...
@@ -31,8 +31,8 @@ class TestFunctionGraph:
s
=
pickle
.
dumps
(
func
)
s
=
pickle
.
dumps
(
func
)
new_func
=
pickle
.
loads
(
s
)
new_func
=
pickle
.
loads
(
s
)
assert
all
(
type
(
a
)
==
type
(
b
)
for
a
,
b
in
zip
(
func
.
inputs
,
new_func
.
inputs
))
assert
all
(
type
(
a
)
is
type
(
b
)
for
a
,
b
in
zip
(
func
.
inputs
,
new_func
.
inputs
))
assert
all
(
type
(
a
)
==
type
(
b
)
for
a
,
b
in
zip
(
func
.
outputs
,
new_func
.
outputs
))
assert
all
(
type
(
a
)
is
type
(
b
)
for
a
,
b
in
zip
(
func
.
outputs
,
new_func
.
outputs
))
assert
all
(
assert
all
(
type
(
a
.
op
)
is
type
(
b
.
op
)
# noqa: E721
type
(
a
.
op
)
is
type
(
b
.
op
)
# noqa: E721
for
a
,
b
in
zip
(
func
.
apply_nodes
,
new_func
.
apply_nodes
)
for
a
,
b
in
zip
(
func
.
apply_nodes
,
new_func
.
apply_nodes
)
...
...
tests/graph/test_op.py
浏览文件 @
19358099
...
@@ -25,7 +25,7 @@ class MyType(Type):
...
@@ -25,7 +25,7 @@ class MyType(Type):
self
.
thingy
=
thingy
self
.
thingy
=
thingy
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
other
)
==
type
(
self
)
and
other
.
thingy
==
self
.
thingy
return
type
(
other
)
is
type
(
self
)
and
other
.
thingy
==
self
.
thingy
def
__str__
(
self
):
def
__str__
(
self
):
return
str
(
self
.
thingy
)
return
str
(
self
.
thingy
)
...
...
tests/link/c/test_basic.py
浏览文件 @
19358099
...
@@ -71,7 +71,7 @@ class TDouble(CType):
...
@@ -71,7 +71,7 @@ class TDouble(CType):
return
(
1
,)
return
(
1
,)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
return
type
(
self
)
is
type
(
other
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
...
...
tests/sparse/test_basic.py
浏览文件 @
19358099
...
@@ -348,7 +348,7 @@ class TestVerifyGradSparse:
...
@@ -348,7 +348,7 @@ class TestVerifyGradSparse:
self
.
structured
=
structured
self
.
structured
=
structured
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
and
self
.
structured
==
other
.
structured
return
(
type
(
self
)
is
type
(
other
))
and
self
.
structured
==
other
.
structured
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
structured
)
return
hash
(
type
(
self
))
^
hash
(
self
.
structured
)
...
...
tests/tensor/test_basic.py
浏览文件 @
19358099
...
@@ -3163,7 +3163,7 @@ def test_stack():
...
@@ -3163,7 +3163,7 @@ def test_stack():
sx
,
sy
=
dscalar
(),
dscalar
()
sx
,
sy
=
dscalar
(),
dscalar
()
rval
=
inplace_func
([
sx
,
sy
],
stack
([
sx
,
sy
]))(
-
4.0
,
-
2.0
)
rval
=
inplace_func
([
sx
,
sy
],
stack
([
sx
,
sy
]))(
-
4.0
,
-
2.0
)
assert
type
(
rval
)
==
np
.
ndarray
assert
type
(
rval
)
is
np
.
ndarray
assert
[
-
4
,
-
2
]
==
list
(
rval
)
assert
[
-
4
,
-
2
]
==
list
(
rval
)
...
...
tests/tensor/test_subtensor.py
浏览文件 @
19358099
...
@@ -819,7 +819,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
...
@@ -819,7 +819,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
assert
np
.
allclose
(
val
,
good
),
(
val
,
good
)
assert
np
.
allclose
(
val
,
good
),
(
val
,
good
)
# Test reuse of output memory
# Test reuse of output memory
if
type
(
AdvancedSubtensor1
)
==
AdvancedSubtensor1
:
if
type
(
AdvancedSubtensor1
)
is
AdvancedSubtensor1
:
op
=
AdvancedSubtensor1
()
op
=
AdvancedSubtensor1
()
# When idx is a TensorConstant.
# When idx is a TensorConstant.
if
hasattr
(
idx
,
"data"
):
if
hasattr
(
idx
,
"data"
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论