Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4573a825
提交
4573a825
authored
5月 21, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cleanup
上级
3d9074ac
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
18 行增加
和
498 行删除
+18
-498
_test_opt.py
gof/_test_opt.py
+1
-57
_test_toolbox.py
gof/_test_toolbox.py
+9
-8
opt.py
gof/opt.py
+6
-432
utils.py
gof/utils.py
+2
-1
没有找到文件。
gof/_test_opt.py
浏览文件 @
4573a825
...
@@ -245,26 +245,6 @@ class _test_PatternOptimizer(unittest.TestCase):
...
@@ -245,26 +245,6 @@ class _test_PatternOptimizer(unittest.TestCase):
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]"
# assert str(g) == "[Op3(x, y)]"
# class _test_PatternDescOptimizer(unittest.TestCase):
# def test_replace_output(self):
# # replacing the whole graph
# x, y, z = inputs()
# e = op1(op2(x, y), z)
# g = env([x, y, z], [e])
# PatternDescOptimizer((Op1, (Op2, '1', '2'), '3'),
# (Op4, '3', '2')).optimize(g)
# assert str(g) == "[Op4(z, y)]"
# def test_eq(self):
# x, y, z = inputs()
# e = op1(op_y(x, y, 37, 88), op2(op_y(y, z, 1, 7)))
# g = env([x, y, z], [e])
# PatternDescOptimizer((op_z, '1', '2'),
# (op3, '2', '1')).optimize(g)
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
OpSubOptimizer
=
lambda
op1
,
op2
:
TopoOptimizer
(
OpSub
(
op1
,
op2
))
OpSubOptimizer
=
lambda
op1
,
op2
:
TopoOptimizer
(
OpSub
(
op1
,
op2
))
...
@@ -384,42 +364,6 @@ class _test_MergeOptimizer(unittest.TestCase):
...
@@ -384,42 +364,6 @@ class _test_MergeOptimizer(unittest.TestCase):
# class _test_ConstantFinder(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op1(x, y, z)
# g = env([x], [e])
# ConstantFinder().optimize(g)
# assert y.constant and z.constant
# MergeOptimizer().optimize(g)
# assert str(g) == "[Op1(x, y, y)]" \
# or str(g) == "[Op1(x, z, z)]"
# def test_deep(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op1(op2(x, y), op2(x, y), op2(x, z))
# g = env([x], [e])
# ConstantFinder().optimize(g)
# assert y.constant and z.constant
# MergeOptimizer().optimize(g)
# assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
# or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
# def test_destroyed_orphan_not_constant(self):
# x, y, z = inputs()
# y.data = 2
# z.data = 2
# e = op_d(x, op2(y, z)) # here x is destroyed by op_d
# g = env([y], [e])
# ConstantFinder().optimize(g)
# assert not getattr(x, 'constant', False) and z.constant
# MergeOptimizer().optimize(g)
reenter
=
Exception
(
"Re-Entered"
)
reenter
=
Exception
(
"Re-Entered"
)
class
LoopyMacro
(
Macro
):
class
LoopyMacro
(
Macro
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -448,7 +392,7 @@ class _test_ExpandMacro(unittest.TestCase):
...
@@ -448,7 +392,7 @@ class _test_ExpandMacro(unittest.TestCase):
x
,
y
,
z
=
inputs
()
x
,
y
,
z
=
inputs
()
e
=
Macro1
()(
x
,
y
)
e
=
Macro1
()(
x
,
y
)
g
=
Env
([
x
,
y
],
[
e
])
g
=
Env
([
x
,
y
],
[
e
])
expand_macros
.
optimize
(
g
)
ExpandMacros
()
.
optimize
(
g
)
assert
str
(
g
)
==
"[Op1(y, x)]"
assert
str
(
g
)
==
"[Op1(y, x)]"
def
test_loopy_1
(
self
):
def
test_loopy_1
(
self
):
...
...
gof/_test_toolbox.py
浏览文件 @
4573a825
...
@@ -100,14 +100,15 @@ class _test_NodeFinder(unittest.TestCase):
...
@@ -100,14 +100,15 @@ class _test_NodeFinder(unittest.TestCase):
if
not
len
([
x
for
x
in
g
.
get_nodes
(
type
)])
==
num
:
if
not
len
([
x
for
x
in
g
.
get_nodes
(
type
)])
==
num
:
self
.
fail
((
type
,
num
))
self
.
fail
((
type
,
num
))
def
test_robustness
(
self
):
# def test_robustness(self):
x
,
y
,
z
=
inputs
()
# # this test used to make sense to have, but it doesn't work like that anymore
e
=
add
(
add
(
sigmoid
(
x
),
sigmoid
(
sigmoid
(
z
))),
dot
(
add
(
x
,
y
),
dot
(
y
,
z
)))
# x, y, z = inputs()
g
=
Env
([
x
,
y
,
z
],
[
e
])
# e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g
.
extend
(
NodeFinder
())
# g = Env([x, y, z], [e])
gen
=
g
.
get_nodes
(
sigmoid
)
# I want to get Sigmoid instances
# g.extend(NodeFinder())
g
.
replace
(
e
,
add
(
x
,
y
))
# but here I prune them all
# gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
assert
len
([
x
for
x
in
gen
])
==
0
# the generator should not yield them
# g.replace(e, add(x, y)) # but here I prune them all
# assert len([x for x in gen]) == 0 # the generator should not yield them
...
...
gof/opt.py
浏览文件 @
4573a825
...
@@ -128,9 +128,7 @@ class MergeOptimizer(Optimizer):
...
@@ -128,9 +128,7 @@ class MergeOptimizer(Optimizer):
"""
"""
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
try
:
env
.
extend
(
toolbox
.
ReplaceValidate
())
env
.
extend
(
toolbox
.
ReplaceValidate
())
except
:
pass
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
cid
=
_metadict
()
#result -> result.desc() (for constants)
cid
=
_metadict
()
#result -> result.desc() (for constants)
...
@@ -139,7 +137,7 @@ class MergeOptimizer(Optimizer):
...
@@ -139,7 +137,7 @@ class MergeOptimizer(Optimizer):
sig
=
r
.
signature
()
sig
=
r
.
signature
()
other_r
=
inv_cid
.
get
(
sig
,
None
)
other_r
=
inv_cid
.
get
(
sig
,
None
)
if
other_r
is
not
None
:
if
other_r
is
not
None
:
env
.
replace
(
r
,
other_r
)
env
.
replace
_validate
(
r
,
other_r
)
else
:
else
:
cid
[
r
]
=
sig
cid
[
r
]
=
sig
inv_cid
[
sig
]
=
r
inv_cid
[
sig
]
=
r
...
@@ -559,6 +557,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -559,6 +557,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
NodeFinder
())
def
keep_going
(
exc
,
nav
,
repl_pairs
):
pass
##############################
##############################
### Pre-defined optimizers ###
### Pre-defined optimizers ###
##############################
##############################
...
@@ -567,431 +569,3 @@ def ExpandMacros(filter = None):
...
@@ -567,431 +569,3 @@ def ExpandMacros(filter = None):
return
TopoOptimizer
(
ExpandMacro
(
filter
=
filter
),
return
TopoOptimizer
(
ExpandMacro
(
filter
=
filter
),
order
=
'in_to_out'
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
)
ignore_newtrees
=
False
)
# class TopoOptimizer(Optimizer):
# def __init__(self, local_opt, order = 'out_to_in', ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if order not in ['out_to_in', 'in_to_out']:
# raise ValueError("order must be 'out_to_in' or 'in_to_out'")
# self.order = order
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# q = deque()
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# while q:
# if self.order == 'out_to_in':
# node = q.pop()
# else:
# node = q.popleft()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self, env):
# try:
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpKeyOptimizer(Optimizer):
# def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if not hasattr(local_opt, 'op_key'):
# raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.")
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# op = self.local_opt.op_key()
# q = []
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# if node.op == op: q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# if node.op == op: q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# q = list(env.get_nodes(op))
# while q:
# node = q.pop()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpSpecificOptimizer(LocalOptimizer):
# """
# Generic L{Optimizer} that applies only to ops of a certain
# type. The type in question is accessed through L{self.op}.
# op can also be a class variable of the subclass.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def candidates(self, env):
# """
# Returns all nodes that have L{self.op} in their op field.
# """
# return env.get_nodes(self.op)
# class OpSubOptimizer(Optimizer):
# """
# Replaces all applications of a certain op by the application of
# another op that take the same inputs as what they are replacing.
# e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
# OpSubOptimizer requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# def add_requirements(self, env):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op1, op2, failure_callback = None):
# """
# op1.make_node and op2.make_node must take the same number of
# inputs and have the same number of outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to do a replacement in the graph. The
# arguments to the callback are: (node, replacement, exception)
# """
# self.op1 = op1
# self.op2 = op2
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Replaces all applications of self.op1 by applications of self.op2
# with the same inputs.
# """
# candidates = env.get_nodes(self.op1)
# for node in candidates:
# try:
# repl = self.op2.make_node(*node.inputs)
# assert len(node.outputs) == len(repl.outputs)
# for old, new in zip(node.outputs, repl.outputs):
# env.replace_validate(old, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, repl, e)
# def str(self):
# return "%s -> %s" % (self.op1, self.op2)
# class OpRemover(Optimizer):
# """
# @todo untested
# Removes all applications of an op by transferring each of its
# outputs to the corresponding input.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op, failure_callback = None):
# """
# Applications of the op must have as many inputs as outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to remove an operation in the graph. The
# arguments to the callback are: (node, exception)
# """
# self.op = op
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Removes all applications of self.op.
# """
# candidates = env.get_nodes(self.op)
# for node in candidates:
# try:
# assert len(node.inputs) == len(node.outputs)
# for input, output in zip(node.inputs, node.outputs):
# env.replace(output, input)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, e)
# pass
# def str(self):
# return "f(%s(x)) -> f(x)" % self.op
# class PatternOptimizer(OpSpecificOptimizer):
# """
# @todo update
# Replaces all occurrences of the input pattern by the output pattern:
# input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
# input_pattern ::= dict(pattern = <input_pattern>,
# constraint = <constraint>)
# sub_pattern ::= input_pattern
# sub_pattern ::= string
# sub_pattern ::= a Constant instance
# constraint ::= lambda env, expr: additional matching condition
# output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
# output_pattern ::= string
# Each string in the input pattern is a variable that will be set to
# whatever expression is found in its place. If the same string is
# used more than once, the same expression must be found in those
# places. If a string used in the input pattern is used in the
# output pattern, the matching expression will be inserted in its
# place. The input pattern cannot just be a string but the output
# pattern can.
# If you put a constant result in the input pattern, there will be a
# match iff a constant result with the same value and the same type
# is found in its place.
# You can add a constraint to the match by using the dict(...) form
# described above with a 'constraint' key. The constraint must be a
# function that takes the env and the current Result that we are
# trying to match and returns True or False according to an
# arbitrary criterion.
# Examples:
# PatternOptimizer((add, 'x', 'y'), (add, 'y', 'x'))
# PatternOptimizer((multiply, 'x', 'x'), (square, 'x'))
# PatternOptimizer((subtract, (add, 'x', 'y'), 'y'), 'x')
# PatternOptimizer((power, 'x', Constant(double, 2.0)), (square, 'x'))
# PatternOptimizer((boggle, {'pattern': 'x',
# 'constraint': lambda env, expr: expr.type == scrabble}),
# (scrabble, 'x'))
# """
# def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
# """
# Creates a PatternOptimizer that replaces occurrences of
# in_pattern by occurrences of out_pattern.
# If failure_callback is not None, if there is a match but a
# replacement fails to occur, the callback will be called with
# arguments (result_to_replace, replacement, exception).
# If allow_multiple_clients is False, he pattern matching will
# fail if one of the subpatterns has more than one client.
# """
# self.in_pattern = in_pattern
# self.out_pattern = out_pattern
# if isinstance(in_pattern, (list, tuple)):
# self.op = self.in_pattern[0]
# elif isinstance(in_pattern, dict):
# self.op = self.in_pattern['pattern'][0]
# else:
# raise TypeError("The pattern to search for must start with a specific Op instance.")
# self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
# self.failure_callback = failure_callback
# self.allow_multiple_clients = allow_multiple_clients
# def apply_on_node(self, env, node):
# """
# Checks if the graph from node corresponds to in_pattern. If it does,
# constructs out_pattern and performs the replacement.
# """
# def match(pattern, expr, u, first = False):
# if isinstance(pattern, (list, tuple)):
# if expr.owner is None:
# return False
# if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
# return False
# if len(pattern) - 1 != len(expr.owner.inputs):
# return False
# for p, v in zip(pattern[1:], expr.owner.inputs):
# u = match(p, v, u)
# if not u:
# return False
# elif isinstance(pattern, dict):
# try:
# real_pattern = pattern['pattern']
# constraint = pattern['constraint']
# except KeyError:
# raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
# if constraint(env, expr):
# return match(real_pattern, expr, u, False)
# elif isinstance(pattern, str):
# v = unify.Var(pattern)
# if u[v] is not v and u[v] is not expr:
# return False
# else:
# u = u.merge(expr, v)
# elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
# return u
# else:
# return False
# return u
# def build(pattern, u):
# if isinstance(pattern, (list, tuple)):
# args = [build(p, u) for p in pattern[1:]]
# return pattern[0](*args)
# elif isinstance(pattern, str):
# return u[unify.Var(pattern)]
# else:
# return pattern
# u = match(self.in_pattern, node.out, unify.Unification(), True)
# if u:
# try:
# # note: only replaces the default 'out' port if it exists
# p = self.out_pattern
# new = 'unassigned' # this is for the callback if build fails
# new = build(p, u)
# env.replace(node.out, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node.out, new, e)
# pass
# def __str__(self):
# def pattern_to_str(pattern):
# if isinstance(pattern, (list, tuple)):
# return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
# elif isinstance(pattern, dict):
# return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
# else:
# return str(pattern)
# return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
# class LocalOptimizer(Optimizer):
# """
# Generic L{Optimizer} class that considers local parts of
# the L{Env}. It must be subclassed and should override the
# following two methods:
# - candidates(env) -> returns a set of ops that can be
# optimized
# - apply_on_node(env, node) -> for each node in candidates,
# this function will be called to perform the actual
# optimization.
# """
# def candidates(self, env):
# """
# Must return a set of nodes that can be optimized.
# """
# raise utils.AbstractFunctionError()
# def apply_on_node(self, env, node):
# """
# For each node in candidates, this function will be called to
# perform the actual optimization.
# """
# raise utils.AbstractFunctionError()
# def apply(self, env):
# """
# Calls self.apply_on_op(env, op) for each op in self.candidates(env).
# """
# for node in self.candidates(env):
# if node in env.nodes:
# self.apply_on_node(env, node)
gof/utils.py
浏览文件 @
4573a825
...
@@ -34,8 +34,9 @@ class scratchpad:
...
@@ -34,8 +34,9 @@ class scratchpad:
self
.
__dict__
.
clear
()
self
.
__dict__
.
clear
()
def
__update__
(
self
,
other
):
def
__update__
(
self
,
other
):
self
.
__dict__
.
update
(
other
.
__dict__
)
self
.
__dict__
.
update
(
other
.
__dict__
)
return
self
def
__str__
(
self
):
def
__str__
(
self
):
print
"scratch"
+
str
(
self
.
__dict__
)
return
"scratch"
+
str
(
self
.
__dict__
)
def
memoize
(
f
):
def
memoize
(
f
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论