Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
980ecacf
提交
980ecacf
authored
1月 25, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add AtomicVariable and NominalVariable classes
上级
96366f99
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
269 行增加
和
70 行删除
+269
-70
basic.py
aesara/graph/basic.py
+107
-13
fg.py
aesara/graph/fg.py
+7
-8
opt.py
aesara/graph/opt.py
+38
-43
basic.py
aesara/link/c/basic.py
+14
-5
test_basic.py
tests/graph/test_basic.py
+77
-0
test_fg.py
tests/graph/test_fg.py
+26
-1
没有找到文件。
aesara/graph/basic.py
浏览文件 @
980ecacf
"""Core graph classes."""
import
abc
import
warnings
from
collections
import
deque
from
copy
import
copy
...
...
@@ -10,6 +11,7 @@ from typing import (
Deque
,
Dict
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
List
,
...
...
@@ -396,6 +398,14 @@ class Variable(Node):
def
owner
(
self
,
value
)
->
None
:
self
.
_owner
=
value
@property
def
index
(
self
):
return
self
.
_index
@index.setter
def
index
(
self
,
value
):
self
.
_index
=
value
def
__init__
(
self
,
type
,
...
...
@@ -411,7 +421,7 @@ class Variable(Node):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
self
.
_
owner
=
owner
self
.
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
...
...
@@ -586,7 +596,98 @@ class Variable(Node):
return
d
class
Constant
(
Variable
):
class
AtomicVariable
(
Variable
):
"""A node type that has no ancestors and should never be considered an input to a graph."""
def
__init__
(
self
,
type
,
**
kwargs
):
super
()
.
__init__
(
type
,
**
kwargs
)
@abc.abstractmethod
def
signature
(
self
):
...
def
merge_signature
(
self
):
return
self
.
signature
()
def
equals
(
self
,
other
):
"""
This does what `__eq__` would normally do, but `Variable` and `Apply`
should always be hashable by `id`.
"""
return
isinstance
(
other
,
type
(
self
))
and
self
.
signature
()
==
other
.
signature
()
@property
def
owner
(
self
):
return
None
@owner.setter
def
owner
(
self
,
value
):
if
value
is
not
None
:
raise
ValueError
(
"AtomicVariable instances cannot have an owner."
)
@property
def
index
(
self
):
return
None
@index.setter
def
index
(
self
,
value
):
if
value
is
not
None
:
raise
ValueError
(
"AtomicVariable instances cannot have an index."
)
class
NominalVariable
(
AtomicVariable
):
"""A variable that enables alpha-equivalent comparisons."""
__instances__
:
Dict
[
Hashable
,
type
]
=
{}
def
__new__
(
cls
,
id
,
typ
,
**
kwargs
):
if
(
id
,
typ
)
not
in
cls
.
__instances__
:
var_type
=
typ
.
variable_type
type_name
=
f
"Nominal{var_type.__name__}"
def
_reduce
(
self
):
return
cls
,
(
self
.
id
,
self
.
type
)
def
_str
(
self
):
return
f
"*{self.id}-{var_type.__str__(self)}"
new_type
=
type
(
type_name
,
(
cls
,
var_type
),
{
"__reduce__"
:
_reduce
,
"__str__"
:
_str
}
)
res
=
super
()
.
__new__
(
new_type
)
cls
.
__instances__
[(
id
,
typ
)]
=
res
return
cls
.
__instances__
[(
id
,
typ
)]
def
__init__
(
self
,
id
,
typ
,
**
kwargs
):
self
.
id
=
id
super
()
.
__init__
(
typ
,
**
kwargs
)
def
clone
(
self
):
return
self
def
__eq__
(
self
,
other
):
if
self
is
other
:
return
True
return
(
type
(
self
)
==
type
(
other
)
and
self
.
id
==
other
.
id
and
self
.
type
==
other
.
type
)
def
__hash__
(
self
):
return
hash
((
type
(
self
),
self
.
id
,
self
.
type
))
def
__repr__
(
self
):
return
f
"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def
signature
(
self
):
return
(
self
.
type
,
self
.
id
)
class
Constant
(
AtomicVariable
):
"""A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant
...
...
@@ -602,23 +703,16 @@ class Constant(Variable):
# __slots__ = ['data']
def
__init__
(
self
,
type
,
data
,
name
=
None
):
super
()
.
__init__
(
type
,
None
,
None
,
name
)
super
()
.
__init__
(
type
,
name
=
name
)
self
.
data
=
type
.
filter
(
data
)
add_tag_trace
(
self
)
def
get_test_value
(
self
):
return
self
.
data
def
equals
(
self
,
other
):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return
isinstance
(
other
,
Constant
)
and
self
.
signature
()
==
other
.
signature
()
def
signature
(
self
):
return
(
self
.
type
,
self
.
data
)
def
merge_signature
(
self
):
return
self
.
signature
()
def
__str__
(
self
):
if
self
.
name
is
not
None
:
return
self
.
name
...
...
@@ -641,9 +735,9 @@ class Constant(Variable):
if
value
is
not
None
:
raise
ValueError
(
"Constant instances cannot have an owner."
)
value
=
property
(
lambda
self
:
self
.
data
,
doc
=
"read-only data access method"
)
# index is not defined, because the `owner` attribute must necessarily be None
@property
def
value
(
self
):
return
self
.
data
def
walk
(
...
...
aesara/graph/fg.py
浏览文件 @
980ecacf
...
...
@@ -18,7 +18,7 @@ from typing_extensions import Literal
import
aesara
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Constant
,
Node
,
Variable
,
applys_between
from
aesara.graph.basic
import
Apply
,
AtomicVariable
,
Node
,
Variable
,
applys_between
from
aesara.graph.basic
import
as_string
as
graph_as_string
from
aesara.graph.basic
import
clone_get_equiv
,
graph_inputs
,
io_toposort
,
vars_between
from
aesara.graph.features
import
AlreadyThere
,
Feature
,
ReplaceValidate
...
...
@@ -101,7 +101,9 @@ class FunctionGraph(MetaObject):
raise
ValueError
(
"No outputs specified"
)
if
inputs
is
None
:
inputs
=
[
i
for
i
in
graph_inputs
(
outputs
)
if
not
isinstance
(
i
,
Constant
)]
inputs
=
[
i
for
i
in
graph_inputs
(
outputs
)
if
not
isinstance
(
i
,
AtomicVariable
)
]
if
clone
:
_memo
=
clone_get_equiv
(
...
...
@@ -306,7 +308,7 @@ class FunctionGraph(MetaObject):
self
.
import_node
(
var
.
owner
,
reason
=
reason
,
import_missing
=
import_missing
)
elif
(
var
.
owner
is
None
and
not
isinstance
(
var
,
Constant
)
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
):
from
aesara.graph.null_type
import
NullType
...
...
@@ -354,7 +356,7 @@ class FunctionGraph(MetaObject):
for
var
in
node
.
inputs
:
if
(
var
.
owner
is
None
and
not
isinstance
(
var
,
Constant
)
and
not
isinstance
(
var
,
AtomicVariable
)
and
var
not
in
self
.
inputs
):
if
import_missing
:
...
...
@@ -515,9 +517,6 @@ class FunctionGraph(MetaObject):
)
for
node
,
i
in
list
(
self
.
clients
[
var
]):
assert
(
node
==
"output"
and
self
.
outputs
[
i
]
is
var
)
or
(
isinstance
(
node
,
Apply
)
and
node
.
inputs
[
i
]
is
var
)
self
.
change_node_input
(
node
,
i
,
new_var
,
reason
=
reason
,
import_missing
=
import_missing
)
...
...
@@ -839,7 +838,7 @@ class FunctionGraph(MetaObject):
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
Constant
)
and
not
isinstance
(
variable
,
AtomicVariable
)
):
raise
Exception
(
f
"Undeclared input: {variable}"
)
for
cl_node
,
i
in
self
.
clients
[
variable
]:
...
...
aesara/graph/opt.py
浏览文件 @
980ecacf
...
...
@@ -19,13 +19,12 @@ from functools import _compose_mro, partial, reduce # type: ignore
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing_extensions
import
TypeAlias
import
aesara
from
aesara.configdefaults
import
config
from
aesara.graph
import
destroyhandler
as
dh
from
aesara.graph.basic
import
(
Apply
,
AtomicVariable
,
Constant
,
Variable
,
applys_between
,
...
...
@@ -500,12 +499,9 @@ class MergeFeature(Feature):
assert
not
hasattr
(
fgraph
,
"merge_feature"
)
fgraph
.
merge_feature
=
self
# For constants
self
.
seen_constants
=
set
()
# variable -> signature (for constants)
self
.
const_sig
=
AssocList
()
# signature -> variable (for constants)
self
.
const_sig_inv
=
AssocList
()
self
.
seen_atomics
=
set
()
self
.
atomic_sig
=
AssocList
()
self
.
atomic_sig_inv
=
AssocList
()
# For all Apply nodes
# Set of distinct (not mergeable) nodes
...
...
@@ -539,17 +535,13 @@ class MergeFeature(Feature):
self
.
nodes_seen
.
discard
(
node
)
self
.
process_node
(
fgraph
,
node
)
# Since we are in on_change_input, node should have inputs.
if
not
isinstance
(
node
,
str
):
assert
node
.
inputs
if
isinstance
(
new_r
,
Constant
):
self
.
process_constant
(
fgraph
,
new_r
)
if
isinstance
(
new_r
,
AtomicVariable
):
self
.
process_atomic
(
fgraph
,
new_r
)
def
on_import
(
self
,
fgraph
,
node
,
reason
):
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
Constant
):
self
.
process_
constant
(
fgraph
,
c
)
if
isinstance
(
c
,
AtomicVariable
):
self
.
process_
atomic
(
fgraph
,
c
)
self
.
process_node
(
fgraph
,
node
)
...
...
@@ -558,19 +550,19 @@ class MergeFeature(Feature):
if
not
node
.
inputs
:
self
.
noinput_nodes
.
discard
(
node
)
for
c
in
node
.
inputs
:
if
isinstance
(
c
,
Constant
)
and
(
len
(
fgraph
.
clients
[
c
])
<=
1
)
:
if
isinstance
(
c
,
AtomicVariable
)
and
len
(
fgraph
.
clients
[
c
])
<=
1
:
# This was the last node using this constant
sig
=
self
.
const
_sig
[
c
]
self
.
const
_sig
.
discard
(
c
)
self
.
const
_sig_inv
.
discard
(
sig
)
self
.
seen_
constant
s
.
discard
(
id
(
c
))
def
process_
constant
(
self
,
fgraph
,
c
):
"""Check if a
constant
`c` can be merged, and queue that replacement."""
if
id
(
c
)
in
self
.
seen_
constant
s
:
sig
=
self
.
atomic
_sig
[
c
]
self
.
atomic
_sig
.
discard
(
c
)
self
.
atomic
_sig_inv
.
discard
(
sig
)
self
.
seen_
atomic
s
.
discard
(
id
(
c
))
def
process_
atomic
(
self
,
fgraph
,
c
):
"""Check if a
n atomic
`c` can be merged, and queue that replacement."""
if
id
(
c
)
in
self
.
seen_
atomic
s
:
return
sig
=
c
.
merge_signature
()
other_c
=
self
.
const
_sig_inv
.
get
(
sig
,
None
)
other_c
=
self
.
atomic
_sig_inv
.
get
(
sig
,
None
)
if
other_c
is
not
None
:
# multiple names will clobber each other..
# we adopt convention to keep the last name
...
...
@@ -579,9 +571,9 @@ class MergeFeature(Feature):
self
.
scheduled
.
append
([[(
c
,
other_c
,
"merge"
)]])
else
:
# this is a new constant
self
.
const
_sig
[
c
]
=
sig
self
.
const
_sig_inv
[
sig
]
=
c
self
.
seen_
constant
s
.
add
(
id
(
c
))
self
.
atomic
_sig
[
c
]
=
sig
self
.
atomic
_sig_inv
[
sig
]
=
c
self
.
seen_
atomic
s
.
add
(
id
(
c
))
def
process_node
(
self
,
fgraph
,
node
):
r"""Check if a `node` can be merged, and queue that replacement.
...
...
@@ -602,6 +594,7 @@ class MergeFeature(Feature):
# using `node.inputs[0]` will make us look at more nodes on
# average, so by picking the smallest clients list, we might speed
# things up?
clients
=
sorted
(
(
fgraph
.
clients
[
inp
]
for
inp
in
node
.
inputs
),
key
=
lambda
x
:
len
(
x
)
)[
0
]
...
...
@@ -616,6 +609,7 @@ class MergeFeature(Feature):
replacement_candidates
=
[]
for
candidate
in
merge_candidates
:
if
candidate
is
node
:
continue
if
len
(
node
.
inputs
)
!=
len
(
candidate
.
inputs
):
...
...
@@ -658,9 +652,10 @@ class MergeOptimizer(GlobalOptimizer):
one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input-to-output order throughout the graph.
The first step of merging is constant-merging, so that all clients of an
``int(1)`` for example, are transferred to just one particular instance of
``int(1)``.
The first step of merging is atomic variable-merging, so that all clients of a
:class:`Constant` like ``int(1)``, are transferred to just one particular
instance of ``int(1)``. :class:`NominalVariable`\s are not merged individually
like this; only the nodes that use them are.
"""
...
...
@@ -678,7 +673,7 @@ class MergeOptimizer(GlobalOptimizer):
callbacks_before
=
fgraph
.
execute_callbacks_times
.
copy
()
nb_merged
=
0
nb_
constant
=
0
nb_
atomic
=
0
while
sched
:
pairs_list
=
sched
.
pop
()
success
=
True
...
...
@@ -739,8 +734,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs
=
[(
pairs
[
0
][
1
],
pairs
[
0
][
0
])]
try
:
# If they're all `
Constant
`s, there's no need to call validate.
if
all
(
isinstance
(
old
,
Constant
)
for
old
,
_
in
pairs
):
# If they're all `
AtomicVariable
`s, there's no need to call validate.
if
all
(
isinstance
(
old
,
AtomicVariable
)
for
old
,
_
in
pairs
):
fgraph
.
replace_all
(
pairs
,
reason
=
"MergeOptimizer"
)
else
:
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"MergeOptimizer"
)
...
...
@@ -753,8 +748,8 @@ class MergeOptimizer(GlobalOptimizer):
if
success
:
nb_merged
+=
len
(
pairs
)
if
isinstance
(
pairs
[
0
][
0
],
Constant
):
nb_
constant
+=
1
if
isinstance
(
pairs
[
0
][
0
],
AtomicVariable
):
nb_
atomic
+=
1
break
if
fgraph
.
profile
:
...
...
@@ -782,7 +777,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time
,
callbacks_time
,
nb_merged
,
nb_
constant
,
nb_
atomic
,
)
def
__str__
(
self
):
...
...
@@ -798,14 +793,14 @@ class MergeOptimizer(GlobalOptimizer):
callback_time
,
callbacks_time
,
nb_merged
,
nb_
constant
,
nb_
atomic
,
)
=
prof
blanc
=
" "
*
level
print
(
blanc
,
"MergeOptimizer"
,
file
=
stream
)
print
(
blanc
,
f
" nb fail={nb_fail:5d} merged={nb_merged:5d}
constant={nb_constant
:5d}"
,
f
" nb fail={nb_fail:5d} merged={nb_merged:5d}
atomic={nb_atomic
:5d}"
,
file
=
stream
,
)
print
(
...
...
@@ -836,7 +831,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time
=
merge_none_number
(
prof1
[
3
],
prof2
[
3
])
callbacks_time
=
merge_dict
(
prof1
[
4
],
prof2
[
4
])
nb_merged
=
prof1
[
5
]
+
prof2
[
5
]
nb_
constant
=
prof1
[
6
]
+
prof2
[
6
]
nb_
atomic
=
prof1
[
6
]
+
prof2
[
6
]
return
(
nb_fail
,
replace_time
,
...
...
@@ -844,7 +839,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time
,
callbacks_time
,
nb_merged
,
nb_
constant
,
nb_
atomic
,
)
...
...
@@ -1127,7 +1122,7 @@ class LocalOptTracker:
def
__init__
(
self
):
self
.
tracked_instances
:
Dict
[
Op
,
List
[
LocalOptimizer
]]
=
{}
self
.
tracked_types
:
Dict
[
TypeAlias
,
List
[
LocalOptimizer
]]
=
{}
self
.
tracked_types
:
Dict
[
type
,
List
[
LocalOptimizer
]]
=
{}
self
.
untracked_opts
:
List
[
LocalOptimizer
]
=
[]
def
add_tracker
(
self
,
rw
:
LocalOptimizer
):
...
...
aesara/link/c/basic.py
浏览文件 @
980ecacf
...
...
@@ -14,7 +14,13 @@ import numpy as np
from
aesara.compile.compilelock
import
lock_ctx
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
NoParams
,
io_toposort
,
vars_between
from
aesara.graph.basic
import
(
AtomicVariable
,
Constant
,
NoParams
,
io_toposort
,
vars_between
,
)
from
aesara.graph.callcache
import
CallCache
from
aesara.link.basic
import
Container
,
Linker
,
LocalLinker
,
PerformLinker
from
aesara.link.c.cmodule
import
(
...
...
@@ -641,7 +647,7 @@ class CLinker(Linker):
self
.
orphans
=
list
(
r
for
r
in
self
.
variables
if
isinstance
(
r
,
Constant
)
and
r
not
in
self
.
inputs
if
isinstance
(
r
,
AtomicVariable
)
and
r
not
in
self
.
inputs
)
# C type constants (aesara.scalar.ScalarType). They don't request an object
self
.
consts
=
[]
...
...
@@ -730,7 +736,7 @@ class CLinker(Linker):
[
get_c_declare
,
get_c_extract
,
get_c_cleanup
],
]
elif
variable
in
self
.
orphans
:
if
not
isinstance
(
variable
,
Constant
):
if
not
isinstance
(
variable
,
AtomicVariable
):
raise
TypeError
(
"All orphans to CLinker must be Constant instances. "
f
"Got {variable}"
...
...
@@ -1404,7 +1410,7 @@ class CLinker(Linker):
# It is important that a variable (i)
# yield a 'position' that reflects its role in code_gen()
if
isinstance
(
i
,
Constant
):
# orphans
if
isinstance
(
i
,
AtomicVariable
):
# orphans
if
id
(
i
)
not
in
constant_ids
:
isig
=
(
i
.
signature
(),
topological_pos
,
i_idx
)
# If the Aesara constant provides a strong hash
...
...
@@ -1634,7 +1640,10 @@ class CLinker(Linker):
]
in_storage
=
[
x
for
i
,
x
in
enumerate
(
in_storage
)
if
i
not
in
dupidx
]
if
storage_map
is
None
:
orphd
=
[[
orphan
.
data
]
for
orphan
in
self
.
orphans
]
orphd
=
[
[
orphan
.
data
]
if
isinstance
(
orphan
,
Constant
)
else
[]
for
orphan
in
self
.
orphans
]
else
:
orphd
=
[
storage_map
[
orphan
]
for
orphan
in
self
.
orphans
]
...
...
tests/graph/test_basic.py
浏览文件 @
980ecacf
...
...
@@ -8,6 +8,7 @@ from aesara import config, function, shared
from
aesara
import
tensor
as
at
from
aesara.graph.basic
import
(
Apply
,
NominalVariable
,
Variable
,
ancestors
,
applys_between
,
...
...
@@ -52,6 +53,9 @@ class MyType(Type):
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
and
other
.
thingy
==
self
.
thingy
def
__hash__
(
self
):
return
hash
((
type
(
self
),
self
.
thingy
))
def
__str__
(
self
):
return
f
"R{self.thingy}"
...
...
@@ -694,3 +698,76 @@ def test_clone_new_inputs():
assert
z_node_new
.
outputs
[
0
]
.
type
.
shape
==
(
1
,)
assert
z_node_new
.
inputs
[
0
]
.
type
.
shape
==
(
1
,)
assert
z_node_new
.
inputs
[
1
]
.
type
.
shape
==
(
1
,)
def
test_NominalVariable
():
type1
=
MyType
(
1
)
nv1
=
NominalVariable
(
1
,
type1
)
nv2
=
NominalVariable
(
1
,
type1
)
assert
nv1
is
nv2
assert
nv1
.
equals
(
nv2
)
assert
hash
(
nv1
)
==
hash
(
nv2
)
type2
=
MyType
(
2
)
nv3
=
NominalVariable
(
1
,
type2
)
assert
not
nv1
.
equals
(
nv3
)
assert
hash
(
nv1
)
!=
hash
(
nv3
)
type3
=
MyType
(
1
)
assert
type3
==
type1
nv4
=
NominalVariable
(
1
,
type3
)
assert
nv1
is
nv4
assert
nv1
.
equals
(
nv4
)
assert
hash
(
nv1
)
==
hash
(
nv4
)
nv5
=
NominalVariable
(
2
,
type3
)
assert
not
nv4
.
equals
(
nv5
)
assert
hash
(
nv4
)
!=
hash
(
nv5
)
assert
repr
(
nv5
)
==
f
"NominalVariable(2, {repr(type3)})"
assert
nv5
.
signature
()
==
(
type3
,
2
)
nv5_pkld
=
pickle
.
dumps
(
nv5
)
nv5_unpkld
=
pickle
.
loads
(
nv5_pkld
)
assert
type
(
nv5_unpkld
)
is
type
(
nv5
)
assert
nv5_unpkld
.
equals
(
nv5
)
assert
nv5_unpkld
is
nv5
nv5_clone
=
nv5
.
clone
()
assert
type
(
nv5_clone
)
is
type
(
nv5
)
assert
nv5_clone
.
equals
(
nv5
)
assert
nv5_clone
is
nv5
def
test_NominalVariable_create_variable_type
():
ttype
=
TensorType
(
"float64"
,
(
None
,
None
))
ntv
=
NominalVariable
(
0
,
ttype
)
assert
isinstance
(
ntv
,
TensorVariable
)
assert
isinstance
(
ntv
,
NominalVariable
)
assert
ntv
.
ndim
==
2
assert
ntv
.
broadcastable
==
(
False
,
False
)
assert
ntv
.
dtype
==
"float64"
ntv2
=
NominalVariable
(
0
,
ttype
)
assert
type
(
ntv2
)
is
type
(
ntv
)
assert
ntv2
.
equals
(
ntv
)
assert
ntv2
is
ntv
ntv_pkld
=
pickle
.
dumps
(
ntv
)
ntv_unpkld
=
pickle
.
loads
(
ntv_pkld
)
assert
type
(
ntv_unpkld
)
is
type
(
ntv
)
assert
ntv_unpkld
.
equals
(
ntv
)
assert
ntv_unpkld
is
ntv
tests/graph/test_fg.py
浏览文件 @
980ecacf
...
...
@@ -4,9 +4,19 @@ import numpy as np
import
pytest
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
NominalVariable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.utils
import
MissingInputError
from
tests.graph.utils
import
MyConstant
,
MyOp
,
MyVariable
,
MyVariable2
,
op1
,
op2
,
op3
from
tests.graph.utils
import
(
MyConstant
,
MyOp
,
MyType
,
MyVariable
,
MyVariable2
,
op1
,
op2
,
op3
,
)
class
TestFunctionGraph
:
...
...
@@ -683,3 +693,18 @@ class TestFunctionGraph:
assert
not
fg
.
variables
assert
not
fg
.
apply_nodes
assert
fg
.
clients
==
{
var1
:
[],
var2
:
[]}
def
test_nominals
(
self
):
t1
=
MyType
()
nm
=
NominalVariable
(
1
,
t1
)
nm2
=
NominalVariable
(
2
,
t1
)
v1
=
op1
(
nm
,
nm2
)
fg
=
FunctionGraph
(
outputs
=
[
v1
],
clone
=
False
)
assert
nm
not
in
fg
.
inputs
assert
nm2
not
in
fg
.
inputs
assert
nm
in
fg
.
variables
assert
nm2
in
fg
.
variables
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论