Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e1d46639
提交
e1d46639
authored
1月 01, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove theano.gof.unify-specific objects from theano.gof.utils
上级
9a32adb3
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
109 行增加
和
170 行删除
+109
-170
unify.py
theano/gof/unify.py
+109
-58
utils.py
theano/gof/utils.py
+0
-112
没有找到文件。
theano/gof/unify.py
浏览文件 @
e1d46639
...
@@ -13,10 +13,117 @@ that satisfies the constraints. That's useful for pattern matching.
...
@@ -13,10 +13,117 @@ that satisfies the constraints. That's useful for pattern matching.
from
copy
import
copy
from
copy
import
copy
from
functools
import
partial
from
functools
import
partial
from
theano.gof.utils
import
ANY_TYPE
,
FALL_THROUGH
,
comm_guard
class
Keyword
:
def
__init__
(
self
,
name
,
nonzero
=
True
):
self
.
name
=
name
self
.
nonzero
=
nonzero
def
__nonzero__
(
self
):
# Python 2.x
return
self
.
__bool__
()
def
__bool__
(
self
):
# Python 3.x
return
self
.
nonzero
def
__str__
(
self
):
return
f
"<{self.name}>"
def
__repr__
(
self
):
return
f
"<{self.name}>"
ABORT
=
Keyword
(
"ABORT"
,
False
)
RETRY
=
Keyword
(
"RETRY"
,
False
)
FAILURE
=
Keyword
(
"FAILURE"
,
False
)
simple_types
=
(
int
,
str
,
float
,
bool
,
type
(
None
),
Keyword
)
################################
ANY_TYPE
=
Keyword
(
"ANY_TYPE"
)
FALL_THROUGH
=
Keyword
(
"FALL_THROUGH"
)
def
comm_guard
(
type1
,
type2
):
def
wrap
(
f
):
old_f
=
f
.
__globals__
[
f
.
__name__
]
def
new_f
(
arg1
,
arg2
,
*
rest
):
if
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)
):
pass
elif
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)
):
arg1
,
arg2
=
arg2
,
arg1
else
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
variable
=
f
(
arg1
,
arg2
,
*
rest
)
if
variable
is
FALL_THROUGH
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
else
:
return
variable
new_f
.
__name__
=
f
.
__name__
def
typename
(
type
):
if
isinstance
(
type
,
Keyword
):
return
str
(
type
)
elif
isinstance
(
type
,
(
tuple
,
list
)):
return
"("
+
", "
.
join
([
x
.
__name__
for
x
in
type
])
+
")"
else
:
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
)
)
return
new_f
return
wrap
def
type_guard
(
type1
):
def
wrap
(
f
):
old_f
=
f
.
__globals__
[
f
.
__name__
]
def
new_f
(
arg1
,
*
rest
):
if
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
):
variable
=
f
(
arg1
,
*
rest
)
if
variable
is
FALL_THROUGH
:
return
old_f
(
arg1
,
*
rest
)
else
:
return
variable
else
:
return
old_f
(
arg1
,
*
rest
)
new_f
.
__name__
=
f
.
__name__
def
typename
(
type
):
if
isinstance
(
type
,
Keyword
):
return
str
(
type
)
elif
isinstance
(
type
,
(
tuple
,
list
)):
return
"("
+
", "
.
join
([
x
.
__name__
for
x
in
type
])
+
")"
else
:
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
)
)
return
new_f
return
wrap
class
Variable
:
class
Variable
:
...
@@ -111,9 +218,6 @@ class VariableInList: # not a subclass of Variable
...
@@ -111,9 +218,6 @@ class VariableInList: # not a subclass of Variable
self
.
variable
=
variable
self
.
variable
=
variable
################################
_all
=
{}
_all
=
{}
...
@@ -133,9 +237,6 @@ OrV = partial(var_lookup, OrVariable)
...
@@ -133,9 +237,6 @@ OrV = partial(var_lookup, OrVariable)
NV
=
partial
(
var_lookup
,
NotVariable
)
NV
=
partial
(
var_lookup
,
NotVariable
)
################################
class
Unification
:
class
Unification
:
"""
"""
This class represents a possible unification of a group of variables
This class represents a possible unification of a group of variables
...
@@ -191,9 +292,6 @@ class Unification:
...
@@ -191,9 +292,6 @@ class Unification:
return
self
.
unif
.
get
(
v
,
(
v
,
None
))[
0
]
return
self
.
unif
.
get
(
v
,
(
v
,
None
))[
0
]
################################
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
unify_walk(a, b, U) returns an Unification where a and b are unified,
unify_walk(a, b, U) returns an Unification where a and b are unified,
...
@@ -416,9 +514,6 @@ def unify_walk(v, o, U):
...
@@ -416,9 +514,6 @@ def unify_walk(v, o, U):
return
FALL_THROUGH
# call the next version of unify_walk that matches the type signature
return
FALL_THROUGH
# call the next version of unify_walk that matches the type signature
################################
class
FVar
:
class
FVar
:
def
__init__
(
self
,
fn
,
*
args
):
def
__init__
(
self
,
fn
,
*
args
):
self
.
fn
=
fn
self
.
fn
=
fn
...
@@ -428,9 +523,6 @@ class FVar:
...
@@ -428,9 +523,6 @@ class FVar:
return
self
.
fn
(
*
[
unify_build
(
arg
,
u
)
for
arg
in
self
.
args
])
return
self
.
fn
(
*
[
unify_build
(
arg
,
u
)
for
arg
in
self
.
args
])
################################
def
unify_merge
(
a
,
b
,
U
):
def
unify_merge
(
a
,
b
,
U
):
return
a
return
a
...
@@ -503,54 +595,13 @@ def unify_merge(v, o, U):
...
@@ -503,54 +595,13 @@ def unify_merge(v, o, U):
return
FALL_THROUGH
# call the next version of unify_walk that matches the type signature
return
FALL_THROUGH
# call the next version of unify_walk that matches the type signature
################################
def
unify_build
(
x
,
U
):
def
unify_build
(
x
,
U
):
return
unify_merge
(
x
,
x
,
U
)
return
unify_merge
(
x
,
x
,
U
)
################################
def
unify
(
a
,
b
):
def
unify
(
a
,
b
):
U
=
unify_walk
(
a
,
b
,
Unification
())
U
=
unify_walk
(
a
,
b
,
Unification
())
if
not
U
:
if
not
U
:
return
None
,
False
return
None
,
False
else
:
else
:
return
unify_merge
(
a
,
b
,
U
),
U
return
unify_merge
(
a
,
b
,
U
),
U
################################
if
__name__
==
"__main__"
:
vx
=
NotVariable
(
"x"
,
[
"big"
,
"bones"
])
vy
=
OrVariable
(
"y"
,
[
"hello"
,
"big"
])
vz
=
V
(
"z"
)
va
=
V
(
"a"
)
vl
=
VariableInList
(
vz
)
pattern1
=
dict
(
hey
=
vx
,
ulala
=
va
,
a
=
1
)
pattern2
=
dict
(
hey
=
vy
,
ulala
=
10
,
b
=
2
)
# pattern1 = ["hello", "big", "bones"]
# pattern2 = vl
# pattern1 = [vx]#, "big", "bones"]
# pattern2 = [vy]#, vy, vz]
U
=
unify_walk
(
pattern1
,
pattern2
,
Unification
())
if
U
:
print
(
U
[
va
])
print
(
U
[
vx
])
print
(
U
[
vy
])
print
(
U
[
vz
])
print
(
unify_merge
(
pattern1
,
pattern2
,
U
))
else
:
print
(
"no match"
)
U
=
unify_walk
((
1
,
2
),
(
va
,
va
),
Unification
())
print
(
U
[
va
])
theano/gof/utils.py
浏览文件 @
e1d46639
...
@@ -461,118 +461,6 @@ def toposort(prereqs_d):
...
@@ -461,118 +461,6 @@ def toposort(prereqs_d):
return
seq
return
seq
class
Keyword
:
def
__init__
(
self
,
name
,
nonzero
=
True
):
self
.
name
=
name
self
.
nonzero
=
nonzero
def
__nonzero__
(
self
):
# Python 2.x
return
self
.
__bool__
()
def
__bool__
(
self
):
# Python 3.x
return
self
.
nonzero
def
__str__
(
self
):
return
f
"<{self.name}>"
def
__repr__
(
self
):
return
f
"<{self.name}>"
ABORT
=
Keyword
(
"ABORT"
,
False
)
RETRY
=
Keyword
(
"RETRY"
,
False
)
FAILURE
=
Keyword
(
"FAILURE"
,
False
)
simple_types
=
(
int
,
str
,
float
,
bool
,
type
(
None
),
Keyword
)
ANY_TYPE
=
Keyword
(
"ANY_TYPE"
)
FALL_THROUGH
=
Keyword
(
"FALL_THROUGH"
)
def
comm_guard
(
type1
,
type2
):
def
wrap
(
f
):
old_f
=
f
.
__globals__
[
f
.
__name__
]
def
new_f
(
arg1
,
arg2
,
*
rest
):
if
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)
):
pass
elif
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)
):
arg1
,
arg2
=
arg2
,
arg1
else
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
variable
=
f
(
arg1
,
arg2
,
*
rest
)
if
variable
is
FALL_THROUGH
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
else
:
return
variable
new_f
.
__name__
=
f
.
__name__
def
typename
(
type
):
if
isinstance
(
type
,
Keyword
):
return
str
(
type
)
elif
isinstance
(
type
,
(
tuple
,
list
)):
return
"("
+
", "
.
join
([
x
.
__name__
for
x
in
type
])
+
")"
else
:
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
)
)
return
new_f
return
wrap
def
type_guard
(
type1
):
def
wrap
(
f
):
old_f
=
f
.
__globals__
[
f
.
__name__
]
def
new_f
(
arg1
,
*
rest
):
if
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
):
variable
=
f
(
arg1
,
*
rest
)
if
variable
is
FALL_THROUGH
:
return
old_f
(
arg1
,
*
rest
)
else
:
return
variable
else
:
return
old_f
(
arg1
,
*
rest
)
new_f
.
__name__
=
f
.
__name__
def
typename
(
type
):
if
isinstance
(
type
,
Keyword
):
return
str
(
type
)
elif
isinstance
(
type
,
(
tuple
,
list
)):
return
"("
+
", "
.
join
([
x
.
__name__
for
x
in
type
])
+
")"
else
:
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
)
)
return
new_f
return
wrap
def
flatten
(
a
):
def
flatten
(
a
):
"""
"""
Recursively flatten tuple, list and set in a list.
Recursively flatten tuple, list and set in a list.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论