提交 41cd2224 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

gof.env.Env.edge was sorely needed + small bugfixes + removed some crud in Env

上级 07ceb6df
......@@ -402,7 +402,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
scalar_name = scalar_opclass.__name__
previous_doc = Broadcast.__doc__
scalar_doc = scalar_opclass.__doc__
scalar_doc = scalar_opclass.__doc__ or ""
if scalar_doc:
scalar_doc = """
%(scalar_name)s documentation:
......@@ -411,7 +411,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
doc = """
Usage: %(name)s(*inputs)
Equivalent to: Broadcast(%(scalar_name)s, inputs, %(inplace_pattern)s)
Equivalent to: Broadcast(scalar.%(scalar_name)s, inputs, %(inplace_pattern)s)
Performs Scalar %(scalar_name)s on each element of the
input tensors.
......
......@@ -421,7 +421,7 @@ class CLinker(Linker):
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]]
elif result in self.temps or not reuse_storage:
elif result in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or not reuse_storage:
......@@ -441,6 +441,8 @@ class CLinker(Linker):
# it is useful for complex outputs to reuse storage at each run, so we only clean up in the destructor
policy = [[get_c_declare, get_c_init, get_c_cleanup],
[get_nothing, get_nothing, get_c_sync]]
else:
raise Exception("what the fuck")
builder, block = struct_result_codeblocks(result, policy, id, symbol, sub)
......
......@@ -155,12 +155,6 @@ class Env(graph.Graph):
"""
if feature_class in self._features:
return # the feature is already present
else:
for other_feature_class in self._features:
if issubclass(other_feature_class, feature_class):
return
elif issubclass(feature_class, other_feature_class):
self.__del_feature__(other_feature_class)
self.__add_feature__(feature_class, do_import)
def __add_feature__(self, feature_class, do_import):
......@@ -193,14 +187,7 @@ class Env(graph.Graph):
pass
def get_feature(self, feature_class):
try:
return self._features[feature_class]
except KeyError:
for other_feature_class in self._features:
if issubclass(other_feature_class, feature_class):
return self._features[other_feature_class]
else:
raise
def has_feature(self, feature_class):
try:
......@@ -213,6 +200,18 @@ class Env(graph.Graph):
"Same as len(self.clients(r))."
return len(self.clients(r))
def edge(self, r):
return r in self.inputs or r in self.orphans()
def follow(self, r):
op = r.owner
if self.edge(r):
return None
else:
if op is None:
raise Exception("what the fuck")
return op.inputs
def ops(self):
"All ops within the subgraph bound by env.inputs and env.outputs."
return self._ops
......
......@@ -67,11 +67,9 @@ class DimShuffleLifter(opt.Optimizer):
if r in seen:
return
seen.add(r)
op = r.owner
if op is None \
or op in env.inputs \
or op in env.orphans():
if env.edge(r):
return
op = r.owner
if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle):
......@@ -121,9 +119,7 @@ def find_cliques(env, through_broadcast = False):
# is False) a Result which needs to be broadcasted.
op = r.owner
if r in env.inputs \
or r in env.orphans() \
or op is None \
if env.edge(r) \
or not isinstance(op, Broadcast) \
or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops
......@@ -155,7 +151,7 @@ def find_cliques(env, through_broadcast = False):
cliques = []
def find_cliques_helper(r):
if r in env.inputs or r in env.orphans():
if env.edge(r):
return
clique_inputs = seek_from(r)
if clique_inputs is None:
......@@ -218,7 +214,7 @@ class CliqueOptimizer(opt.Optimizer):
if r in equiv:
return equiv[r]
op = r.owner
if r in env.inputs or r in env.orphans():
if env.edge(r):
# For each leave we make a Scalar of the corresponding dtype
s = scalar.Scalar(dtype = r.dtype)
_r = r
......
......@@ -71,7 +71,10 @@ class Canonizer(gof.Optimizer):
def canonize(r):
if r in env.inputs or r in env.orphans():
# if r in env.inputs or r in env.orphans():
# return
next = env.follow(r)
if next is None:
return
def flatten(r, nclients_check = True):
......@@ -79,9 +82,11 @@ class Canonizer(gof.Optimizer):
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op = r.owner
if op is None or r in env.inputs or r in env.orphans():
if env.edge(r):
return [r], []
op = r.owner
# if op is None or r in env.inputs or r in env.orphans():
# return [r], []
results = [r2.dtype == r.dtype and flatten(r2) or ([r2], []) for r2 in op.inputs]
if isinstance(op, self.main) and (not nclients_check or env.nclients(r) == 1):
......@@ -103,12 +108,15 @@ class Canonizer(gof.Optimizer):
num, denum = flatten(r, False)
if (num, denum) == ([r], []):
if r.owner is None:
return
else:
for input in r.owner.inputs:
for input in (env.follow(r) or []):
canonize(input)
return
# if r.owner is None:
# return
# else:
# for input in r.owner.inputs:
# canonize(input)
# return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum):
......@@ -194,7 +202,7 @@ def group_powers(env, num, denum):
# and does d[base].append(power).
for factor in list(seq):
op = factor.owner
if op is None or factor in env.inputs or factor in env.orphans():
if env.edge(factor):
continue
if isinstance(op, Exp):
d.setdefault('e', []).append(op.inputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论