提交 aecbbb99 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5001 from nouiz/Composite_name

Postpone Composite name creating
......@@ -1838,9 +1838,7 @@ class _Linker(gof.link.LocalLinker):
thunk.outputs = [storage_map[v] for v in node.outputs]
thunk_other = thunk
else:
new_node = node.op.prepare_node(node, storage_map, compute_map)
if new_node is not None:
node = new_node
node.op.prepare_node(node, storage_map, compute_map)
debug = hasattr(node.op, 'debug_perform')
......
......@@ -1582,6 +1582,9 @@ class CLinker(link.Linker):
# If we can't get a key, then forget the cache mechanism.
module = self.compile_cmodule()
else:
# Set compute_map to None, as CLinker does not support lazy evaluation.
for node in self.node_order:
node.op.prepare_node(node, storage_map, None)
module = get_module_cache().module_from_key(
key=key, lnk=self, keep_lock=keep_lock)
......
......@@ -795,7 +795,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
Make any special modifications that the Op needs before doing
make_thunk().
This can either modify the node inplace or return a new one.
This can modify the node inplace and should return nothing.
"""
pass
......@@ -916,10 +916,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
"""
logger = logging.getLogger('theano.gof.op.Op')
new_node = self.prepare_node(node, storage_map=storage_map,
compute_map=compute_map)
if new_node is not None:
node = new_node
self.prepare_node(node, storage_map=storage_map,
compute_map=compute_map)
if not hasattr(self, '_op_use_c_code'):
warnings.warn(
"The __getstate__ method of '%s' is not implemented correctly."
......
......@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad):
else:
self.__dict__.update(props)
self.__dict__.update(more_props)
# A dict mapping each object to print to its string
# representation. If the graph is a DAG and not a tree, this allows
# each node of the graph to be processed only once. The nodes will
# still be printed many times.
self.memo = {}
def clone(self, props=None, **more_props):
if props is None:
......@@ -361,6 +366,8 @@ class OperatorPrinter:
assert self.assoc in VALID_ASSOC
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -393,9 +400,11 @@ class OperatorPrinter:
else:
s = (" %s " % self.operator).join(input_strings)
if parenthesize:
return "(%s)" % s
r = "(%s)" % s
else:
return s
r = s
pstate.memo[output] = r
return r
class PatternPrinter:
......@@ -409,6 +418,8 @@ class PatternPrinter:
self.patterns.append((pattern[0], pattern[1:]))
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -425,7 +436,9 @@ class PatternPrinter:
for i, x in enumerate(pp_process(input, precedence)
for input, precedence in
zip(node.inputs, precedences)))
return pattern % d
r = pattern % d
pstate.memo[output] = r
return r
class FunctionPrinter:
......@@ -434,6 +447,8 @@ class FunctionPrinter:
self.names = names
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -441,40 +456,27 @@ class FunctionPrinter:
"not the result of an operation" % self.names)
idx = node.outputs.index(output)
name = self.names[idx]
return "%s(%s)" % (name, ", ".join(
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
class MemberPrinter:
def __init__(self, *names):
self.names = names
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
idx = node.outputs.index(output)
name = self.names[idx]
input = node.inputs[0]
return "%s.%s" % (pprinter.process(input,
pstate.clone(precedence=1000)),
name)
pstate.memo[output] = r
return r
class IgnorePrinter:
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
input = node.inputs[0]
return "%s" % pprinter.process(input, pstate)
r = "%s" % pprinter.process(input, pstate)
pstate.memo[output] = r
return r
class DefaultPrinter:
......@@ -482,22 +484,30 @@ class DefaultPrinter:
def __init__(self):
pass
def process(self, r, pstate):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = r.owner
node = output.owner
if node is None:
return LeafPrinter().process(r, pstate)
return "%s(%s)" % (str(node.op), ", ".join(
return LeafPrinter().process(output, pstate)
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
pstate.memo[output] = r
return r
class LeafPrinter:
def process(self, r, pstate):
if r.name in greek:
return greek[r.name]
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
if output.name in greek:
r = greek[output.name]
else:
return str(r)
r = str(output)
pstate.memo[output] = r
return r
class PPrinter:
......
......@@ -3462,6 +3462,8 @@ class Composite(ScalarOp):
init_param = ('inputs', 'outputs')
def __str__(self):
# Return the name of this op. The name is created on demand: if it
# is still None, init_name() is called first to build it.
if self.name is None:
self.init_name()
return self.name
def make_new_inplace(self, output_types_preference=None, name=None):
......@@ -3485,6 +3487,9 @@ class Composite(ScalarOp):
Return the C code for this Composite Op.
"""
# It was already called
if hasattr(self, '_c_code'):
return
subd = dict(chain(
((e, "%%(i%i)s" % i) for i, e in enumerate(self.fgraph.inputs)),
((e, "%%(o%i)s" % i) for i, e in enumerate(self.fgraph.outputs))))
......@@ -3533,21 +3538,46 @@ class Composite(ScalarOp):
Return a list of functions that compute each output of self.
"""
# In the case where the graph is a DAG, but not a tree, like:
# add(*1 -> mul(x, y), *1)
# We have an efficient way to build the executable (we build
# and traverse each node only once).
# But we don't have an efficient execution. We will execute
# like a tree, so nodes that have more than 1 client will be
# executed as many times as their number of clients. In the
# example above, it will calculate *1 twice. Doing otherwise
# would imply making a complicated execution engine.
# We need the fast creation of the executor as we always do it
# even if we will use the C code. The Python implementation is
# already slow, so it is not as important to have a fast
# execution there.
memo = {}
def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow fgraph.toposort but that's (presumably)
# still correct since we only have scalar ops
if r in memo:
return memo[r]
if r in self.fgraph.inputs:
idx = self.fgraph.inputs.index(r)
return lambda inputs: inputs[idx]
def f(inputs):
return inputs[idx]
memo[r] = f
return f
elif r.owner is None: # in fgraph.orphans:
return lambda inputs: r.data
def f(inputs):
return r.data
memo[r] = f
return f
node = r.owner
producers = [compose_impl(input) for input in node.inputs]
def f(inputs):
return node.op.impl(*[p(inputs) for p in producers])
memo[r] = f
return f
self._impls = [compose_impl(r) for r in self.fgraph.outputs]
......@@ -3556,32 +3586,19 @@ class Composite(ScalarOp):
Return a readable string representation of self.fgraph.
"""
try:
rval = self.name
except AttributeError:
if 0:
l = []
for n in self.fgraph.toposort():
if hasattr(n.op, "name") and n.op.name is not None:
v = n.op.name
if v.startswith("Composite"):
v = v[len("Composite"):]
else:
v = n.op.__class__.__name__
l.append(v)
rval = "Composite{" + ",".join(l) + "}"
else:
for i, r in enumerate(self.fgraph.inputs):
r.name = 'i%i' % i
for i, r in enumerate(self.fgraph.outputs):
r.name = 'o%i' % i
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(r.clients) > 1:
r.name = 't%i' % i
rval = "Composite{%s}" % ', '.join([pprint(output) for output
in self.fgraph.outputs])
self.name = rval
rval = self.name
if rval is None:
for i, r in enumerate(self.fgraph.inputs):
r.name = 'i%i' % i
for i, r in enumerate(self.fgraph.outputs):
r.name = 'o%i' % i
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(r.clients) > 1:
r.name = 't%i' % i
rval = "Composite{%s}" % ', '.join([pprint(output) for output
in self.fgraph.outputs])
self.name = rval
def init_fgraph(self):
# The clone done by FunctionGraph is needed as we don't want
......@@ -3642,9 +3659,15 @@ class Composite(ScalarOp):
self.nin = len(inputs)
self.nout = len(outputs)
self.init_fgraph() # self.fgraph
self.init_name() # self.name
self.init_c_code() # self._c_code and self.nodenames
# Postpone the creation in case it isn't needed.
# self.init_name() # self.name
self.name = None
def prepare_node(self, node, storage_map, compute_map):
# Create the Python implementations (self._impls), then call
# prepare_node() on the op of each node returned by
# list_of_nodes(self.inputs, self.outputs).
# The node, storage_map and compute_map arguments are not read here;
# the inner ops are called with None for both maps.
self.init_py_impls() # self._impls
for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None)
def output_types(self, input_types):
if tuple(input_types) != self.inputs_type:
......@@ -3688,6 +3711,9 @@ class Composite(ScalarOp):
raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, node, nodename, inames, onames, sub):
if not hasattr(self, '_c_code'):
self.init_c_code()
d = dict(chain(izip(("i%i" % i for i in xrange(len(inames))), inames),
izip(("o%i" % i for i in xrange(len(onames))),
onames)), **sub)
......@@ -3745,9 +3771,13 @@ class Composite(ScalarOp):
return False
# see __hash__ for comment on why there is no mention of fgraph
# or module cache key here.
if not hasattr(self, '_c_code'):
self.init_c_code() # self._c_code and self.nodenames
return (self._c_code == other._c_code)
def __hash__(self):
if not hasattr(self, '_c_code'):
self.init_c_code() # self._c_code and self.nodenames
rval = hash((type(self),
self.nin,
self.nout,
......@@ -3764,7 +3794,7 @@ class Composite(ScalarOp):
def __getstate__(self):
rval = dict(self.__dict__)
del rval['_impls']
rval.pop('_impls', None)
del rval['fgraph']
return rval
......
......@@ -849,6 +849,14 @@ second dimension
char = numpy.sctype2char(out_dtype)
sig = char * node.nin + '->' + char * node.nout
node.tag.sig = sig
node.tag.fake_node = Apply(
self.scalar_op,
[get_scalar_type(dtype=input.type.dtype).make_variable()
for input in node.inputs],
[get_scalar_type(dtype=output.type.dtype).make_variable()
for output in node.outputs])
self.scalar_op.prepare_node(node.tag.fake_node, None, None)
def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32:
......@@ -991,6 +999,11 @@ second dimension
return rval
def _c_all(self, node, nodename, inames, onames, sub):
# Some ops directly call Elemwise._c_all or Elemwise.c_code.
# To avoid requiring all of them to call prepare_node(), do it here.
# There is no harm if it gets called multiple times.
if not hasattr(node.tag, 'fake_node'):
self.prepare_node(node, None, None)
_inames = inames
_onames = onames
......@@ -1109,11 +1122,7 @@ second dimension
# We generate the C code of the inner loop using the scalar op
task_code = self.scalar_op.c_code(
Apply(self.scalar_op,
[get_scalar_type(dtype=input.type.dtype).make_variable()
for input in node.inputs],
[get_scalar_type(dtype=output.type.dtype).make_variable()
for output in node.outputs]),
node.tag.fake_node,
nodename + '_scalar_',
["%s_i" % s for s in _inames],
["%s_i" % s for s in onames],
......
......@@ -7183,7 +7183,9 @@ def local_add_mul_fusion(node):
for inp in node.inputs:
if (inp.owner and
isinstance(inp.owner.op, Elemwise) and
isinstance(inp.owner.op.scalar_op, s_op)):
isinstance(inp.owner.op.scalar_op, s_op) and
# Do not duplicate the operation.
len(inp.clients) == 1):
new_inp.extend(inp.owner.inputs)
fused = True
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论