提交 41cd2224 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

gof.env.Env.edge was sorely needed + small bugfixes + removed some crud in Env

上级 07ceb6df
...@@ -402,7 +402,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None): ...@@ -402,7 +402,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
scalar_name = scalar_opclass.__name__ scalar_name = scalar_opclass.__name__
previous_doc = Broadcast.__doc__ previous_doc = Broadcast.__doc__
scalar_doc = scalar_opclass.__doc__ scalar_doc = scalar_opclass.__doc__ or ""
if scalar_doc: if scalar_doc:
scalar_doc = """ scalar_doc = """
%(scalar_name)s documentation: %(scalar_name)s documentation:
...@@ -411,7 +411,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None): ...@@ -411,7 +411,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
doc = """ doc = """
Usage: %(name)s(*inputs) Usage: %(name)s(*inputs)
Equivalent to: Broadcast(%(scalar_name)s, inputs, %(inplace_pattern)s) Equivalent to: Broadcast(scalar.%(scalar_name)s, inputs, %(inplace_pattern)s)
Performs Scalar %(scalar_name)s on each element of the Performs Scalar %(scalar_name)s on each element of the
input tensors. input tensors.
......
...@@ -421,7 +421,7 @@ class CLinker(Linker): ...@@ -421,7 +421,7 @@ class CLinker(Linker):
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif result in self.temps or not reuse_storage: elif result in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract # temps don't need to be extracted from Python, so we call c_init rather than c_extract
# they do not need to be relayed to Python, so we don't sync # they do not need to be relayed to Python, so we don't sync
if result.c_is_simple() or not reuse_storage: if result.c_is_simple() or not reuse_storage:
...@@ -441,6 +441,8 @@ class CLinker(Linker): ...@@ -441,6 +441,8 @@ class CLinker(Linker):
# it is useful for complex outputs to reuse storage at each run, so we only clean up in the destructor # it is useful for complex outputs to reuse storage at each run, so we only clean up in the destructor
policy = [[get_c_declare, get_c_init, get_c_cleanup], policy = [[get_c_declare, get_c_init, get_c_cleanup],
[get_nothing, get_nothing, get_c_sync]] [get_nothing, get_nothing, get_c_sync]]
else:
raise Exception("what the fuck")
builder, block = struct_result_codeblocks(result, policy, id, symbol, sub) builder, block = struct_result_codeblocks(result, policy, id, symbol, sub)
......
...@@ -155,12 +155,6 @@ class Env(graph.Graph): ...@@ -155,12 +155,6 @@ class Env(graph.Graph):
""" """
if feature_class in self._features: if feature_class in self._features:
return # the feature is already present return # the feature is already present
else:
for other_feature_class in self._features:
if issubclass(other_feature_class, feature_class):
return
elif issubclass(feature_class, other_feature_class):
self.__del_feature__(other_feature_class)
self.__add_feature__(feature_class, do_import) self.__add_feature__(feature_class, do_import)
def __add_feature__(self, feature_class, do_import): def __add_feature__(self, feature_class, do_import):
...@@ -193,14 +187,7 @@ class Env(graph.Graph): ...@@ -193,14 +187,7 @@ class Env(graph.Graph):
pass pass
def get_feature(self, feature_class): def get_feature(self, feature_class):
try:
return self._features[feature_class] return self._features[feature_class]
except KeyError:
for other_feature_class in self._features:
if issubclass(other_feature_class, feature_class):
return self._features[other_feature_class]
else:
raise
def has_feature(self, feature_class): def has_feature(self, feature_class):
try: try:
...@@ -213,6 +200,18 @@ class Env(graph.Graph): ...@@ -213,6 +200,18 @@ class Env(graph.Graph):
"Same as len(self.clients(r))." "Same as len(self.clients(r))."
return len(self.clients(r)) return len(self.clients(r))
def edge(self, r):
return r in self.inputs or r in self.orphans()
def follow(self, r):
op = r.owner
if self.edge(r):
return None
else:
if op is None:
raise Exception("what the fuck")
return op.inputs
def ops(self): def ops(self):
"All ops within the subgraph bound by env.inputs and env.outputs." "All ops within the subgraph bound by env.inputs and env.outputs."
return self._ops return self._ops
......
...@@ -67,11 +67,9 @@ class DimShuffleLifter(opt.Optimizer): ...@@ -67,11 +67,9 @@ class DimShuffleLifter(opt.Optimizer):
if r in seen: if r in seen:
return return
seen.add(r) seen.add(r)
op = r.owner if env.edge(r):
if op is None \
or op in env.inputs \
or op in env.orphans():
return return
op = r.owner
if isinstance(op, DimShuffle): if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle): if isinstance(in_op, DimShuffle):
...@@ -121,9 +119,7 @@ def find_cliques(env, through_broadcast = False): ...@@ -121,9 +119,7 @@ def find_cliques(env, through_broadcast = False):
# is False) a Result which needs to be broadcasted. # is False) a Result which needs to be broadcasted.
op = r.owner op = r.owner
if r in env.inputs \ if env.edge(r) \
or r in env.orphans() \
or op is None \
or not isinstance(op, Broadcast) \ or not isinstance(op, Broadcast) \
or len(op.outputs) > 1: or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops # todo: handle multiple-output broadcast ops
...@@ -155,7 +151,7 @@ def find_cliques(env, through_broadcast = False): ...@@ -155,7 +151,7 @@ def find_cliques(env, through_broadcast = False):
cliques = [] cliques = []
def find_cliques_helper(r): def find_cliques_helper(r):
if r in env.inputs or r in env.orphans(): if env.edge(r):
return return
clique_inputs = seek_from(r) clique_inputs = seek_from(r)
if clique_inputs is None: if clique_inputs is None:
...@@ -218,7 +214,7 @@ class CliqueOptimizer(opt.Optimizer): ...@@ -218,7 +214,7 @@ class CliqueOptimizer(opt.Optimizer):
if r in equiv: if r in equiv:
return equiv[r] return equiv[r]
op = r.owner op = r.owner
if r in env.inputs or r in env.orphans(): if env.edge(r):
# For each leave we make a Scalar of the corresponding dtype # For each leave we make a Scalar of the corresponding dtype
s = scalar.Scalar(dtype = r.dtype) s = scalar.Scalar(dtype = r.dtype)
_r = r _r = r
......
...@@ -71,7 +71,10 @@ class Canonizer(gof.Optimizer): ...@@ -71,7 +71,10 @@ class Canonizer(gof.Optimizer):
def canonize(r): def canonize(r):
if r in env.inputs or r in env.orphans(): # if r in env.inputs or r in env.orphans():
# return
next = env.follow(r)
if next is None:
return return
def flatten(r, nclients_check = True): def flatten(r, nclients_check = True):
...@@ -79,9 +82,11 @@ class Canonizer(gof.Optimizer): ...@@ -79,9 +82,11 @@ class Canonizer(gof.Optimizer):
# into a list of numerators and a list of denominators # into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y] # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op = r.owner if env.edge(r):
if op is None or r in env.inputs or r in env.orphans():
return [r], [] return [r], []
op = r.owner
# if op is None or r in env.inputs or r in env.orphans():
# return [r], []
results = [r2.dtype == r.dtype and flatten(r2) or ([r2], []) for r2 in op.inputs] results = [r2.dtype == r.dtype and flatten(r2) or ([r2], []) for r2 in op.inputs]
if isinstance(op, self.main) and (not nclients_check or env.nclients(r) == 1): if isinstance(op, self.main) and (not nclients_check or env.nclients(r) == 1):
...@@ -103,12 +108,15 @@ class Canonizer(gof.Optimizer): ...@@ -103,12 +108,15 @@ class Canonizer(gof.Optimizer):
num, denum = flatten(r, False) num, denum = flatten(r, False)
if (num, denum) == ([r], []): if (num, denum) == ([r], []):
if r.owner is None: for input in (env.follow(r) or []):
return
else:
for input in r.owner.inputs:
canonize(input) canonize(input)
return return
# if r.owner is None:
# return
# else:
# for input in r.owner.inputs:
# canonize(input)
# return
# Terms that are both in the num and denum lists cancel each other # Terms that are both in the num and denum lists cancel each other
for d in list(denum): for d in list(denum):
...@@ -194,7 +202,7 @@ def group_powers(env, num, denum): ...@@ -194,7 +202,7 @@ def group_powers(env, num, denum):
# and does d[base].append(power). # and does d[base].append(power).
for factor in list(seq): for factor in list(seq):
op = factor.owner op = factor.owner
if op is None or factor in env.inputs or factor in env.orphans(): if env.edge(factor):
continue continue
if isinstance(op, Exp): if isinstance(op, Exp):
d.setdefault('e', []).append(op.inputs[0]) d.setdefault('e', []).append(op.inputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论