提交 8c5f2e91 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make elemwise inplace profiler be well done.

上级 23f4d762
......@@ -743,7 +743,7 @@ optdb.register('gpua_elemwise_fusion',
tensor.opt.FusionOptimizer(gpu_local_elemwise_fusion), 49,
'fast_run', 'fusion', 'local_elemwise_fusion', 'gpuarray')
inplace_gpu_elemwise_opt = tensor.opt.inplace_elemwise_optimizer_op(
inplace_gpu_elemwise_opt = tensor.opt.InplaceElemwiseOptimizer(
GpuElemwise)
optdb.register('gpua_inplace_opt', inplace_gpu_elemwise_opt, 75,
'inplace_elemwise_optimizer', 'fast_run', 'inplace', 'gpuarray')
......
......@@ -2181,7 +2181,7 @@ else:
71.00, 'fusion', 'local_elemwise_fusion')
# GpuElemwise inplace
gpu_inplace_elemwise_optimizer = tensor.opt.inplace_elemwise_optimizer_op(
gpu_inplace_elemwise_optimizer = tensor.opt.InplaceElemwiseOptimizer(
GpuElemwise)
# DO NOT PLACE add a 'gpu' tag here! This would enable it in fast_compile.
# It still will be run in fast_run with device=gpu with the current tag.
......
......@@ -23,7 +23,7 @@ from theano import gof
from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace, in2out
from theano.gof.opt import copy_stack_trace, in2out, Optimizer
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
......@@ -147,14 +147,34 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval
def inplace_elemwise_optimizer_op(OP):
class InplaceElemwiseOptimizer(Optimizer):
"""
We parametrise it to make it work for Elemwise and GpuElemwise op.
"""
@gof.inplace_optimizer
def inplace_elemwise_optimizer(fgraph):
def __init__(self, OP):
self.OP = OP
def add_requirements(self, fgraph):
fgraph.attach_feature(theano.gof.destroyhandler.DestroyHandler())
@staticmethod
def print_profile(stream, prof, level=0):
blanc = (' ' * level)
print(blanc, "InplaceElemwiseOptimizer ", prof['opt'].OP, file=stream)
for k in ['node_before',
'nb_call_replace',
'nb_call_validate',
'nb_inconsistent']:
print(blanc, k, prof[k], file=stream)
ndim = prof['ndim']
if ndim:
print(blanc, "ndim", "nb", file=stream)
for n in sorted(ndim.keys()):
print(blanc, n, ndim[n], file=stream)
def apply(self, fgraph):
"""
Usage: inplace_elemwise_optimizer.optimize(fgraph)
Usage: InplaceElemwiseOptimizer(OP).optimize(fgraph)
Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast
......@@ -188,7 +208,8 @@ def inplace_elemwise_optimizer_op(OP):
# the solution is also applicable there.
# We execute `validate` after this number of change.
prof = {'node_before': len(fgraph.apply_nodes),
prof = {'opt': self,
'node_before': len(fgraph.apply_nodes),
'nb_call_replace': 0,
'nb_call_validate': 0,
'nb_inconsistent': 0,
......@@ -217,7 +238,7 @@ def inplace_elemwise_optimizer_op(OP):
for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
# gpuarray GpuElemwise inherit from Elemwise
if not type(op) == OP:
if not type(op) == self.OP:
continue
# If big graph and the outputs are scalar, do not make it
# inplace.
......@@ -334,7 +355,7 @@ def inplace_elemwise_optimizer_op(OP):
scalar.transfer_type(
*[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))]))
new_outputs = OP(new_scal, inplace_pattern)(
new_outputs = self.OP(new_scal, inplace_pattern)(
*node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner
......@@ -343,7 +364,7 @@ def inplace_elemwise_optimizer_op(OP):
fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1
prof['ndim'][node.ndim] += 1
prof['ndim'][candidate_out_var.ndim] += 1
if nb_change_no_validate >= check_each_change:
prof['nb_call_validate'] += 1
fgraph.validate()
......@@ -373,10 +394,11 @@ def inplace_elemwise_optimizer_op(OP):
"performed due to unexpected error"),
file=sys.stderr)
fgraph.revert(chk)
print(prof)
return inplace_elemwise_optimizer
return prof
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
return inplace_elemwise_optimizer
inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(T.Elemwise)
compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
'inplace_opt', # for historic reason
'inplace_elemwise_optimizer',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论