提交 3c70348f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4524 from nouiz/opt

Opt related changes.
...@@ -84,10 +84,15 @@ def _atexit_print_fn(): ...@@ -84,10 +84,15 @@ def _atexit_print_fn():
cum_attr[key] = val cum_attr[key] = val
if cum.optimizer_profile and ps.optimizer_profile: if cum.optimizer_profile and ps.optimizer_profile:
merge = cum.optimizer_profile[0].merge_profile( try:
cum.optimizer_profile[1], merge = cum.optimizer_profile[0].merge_profile(
ps.optimizer_profile[1]) cum.optimizer_profile[1],
cum.optimizer_profile = (cum.optimizer_profile[0], merge) ps.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print("Got an exception while merging profile")
print(e)
cum.optimizer_profile = None
else: else:
cum.optimizer_profile = None cum.optimizer_profile = None
......
...@@ -220,8 +220,10 @@ class SeqOptimizer(Optimizer, list): ...@@ -220,8 +220,10 @@ class SeqOptimizer(Optimizer, list):
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
sub_validate_time = [validate_before] sub_validate_time = [validate_before]
callbacks_before = fgraph.execute_callbacks_times.copy()
else: else:
sub_validate_time = [] sub_validate_time = []
callbacks_before = []
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
nb_node_before = len(fgraph.apply_nodes) nb_node_before = len(fgraph.apply_nodes)
sub_profs = [] sub_profs = []
...@@ -249,12 +251,22 @@ class SeqOptimizer(Optimizer, list): ...@@ -249,12 +251,22 @@ class SeqOptimizer(Optimizer, list):
if fgraph.profile: if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
callbacks_time = {}
for k, v in iteritems(fgraph.execute_callbacks_times):
if k in callbacks_before:
t = v - callbacks_before[k]
if t > 0:
callbacks_time[k] = t
else:
callbacks_time[k] = v
else: else:
validate_time = None validate_time = None
callbacks_time = {}
callback_time = fgraph.execute_callbacks_time - callback_before callback_time = fgraph.execute_callbacks_time - callback_before
return (self, l, validate_time, callback_time, nb_node_before, return (self, l, validate_time, callback_time, nb_node_before,
len(fgraph.apply_nodes), sub_profs, sub_validate_time, len(fgraph.apply_nodes), sub_profs, sub_validate_time,
nb_nodes) nb_nodes, callbacks_time)
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -274,8 +286,9 @@ class SeqOptimizer(Optimizer, list): ...@@ -274,8 +286,9 @@ class SeqOptimizer(Optimizer, list):
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
(opts, prof, validate_time, callback_time, nb_node_before, (opts, prof, validate_time, callback_time,
nb_node_after, sub_profs, sub_validate_time, nb_nodes) = prof nb_node_before, nb_node_after, sub_profs, sub_validate_time,
nb_nodes, callbacks_time) = prof
blanc = (' ' * level) blanc = (' ' * level)
print(blanc, "SeqOptimizer", end=' ', file=stream) print(blanc, "SeqOptimizer", end=' ', file=stream)
...@@ -287,9 +300,20 @@ class SeqOptimizer(Optimizer, list): ...@@ -287,9 +300,20 @@ class SeqOptimizer(Optimizer, list):
" before/after optimization" % ( " before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream) sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream) print(blanc, " %.3fs for callback" % (callback_time), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream) print(blanc, " %.3fs for fgraph.validate()" % (validate_time),
file=stream)
if callback_time > 1:
print(blanc, " callbacks_time", file=stream)
for i in sorted(iteritems(callbacks_time), key=lambda a: -a[1]):
if i[1] > 0:
# We want to have the __str__ called, so we can't
# just print i.
print(blanc, " ", i[0], ',', i[1], file=stream)
if level == 0: if level == 0:
print(blanc, " time - (name, class, index, nodes before, nodes after) - validate time", file=stream) print(blanc,
" time - (name, class, index, nodes before, nodes after) - validate time",
file=stream)
ll = [] ll = []
for opt in opts: for opt in opts:
if hasattr(opt, "__name__"): if hasattr(opt, "__name__"):
...@@ -298,7 +322,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -298,7 +322,7 @@ class SeqOptimizer(Optimizer, list):
name = opt.name name = opt.name
idx = opts.index(opt) idx = opts.index(opt)
ll.append((name, opt.__class__.__name__, ll.append((name, opt.__class__.__name__,
idx) + nb_nodes[idx]) idx))
lll = sorted(zip(prof, ll, nb_nodes), key=lambda a: a[0]) lll = sorted(zip(prof, ll, nb_nodes), key=lambda a: a[0])
for (t, opt, nb_n) in lll[::-1]: for (t, opt, nb_n) in lll[::-1]:
...@@ -375,6 +399,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -375,6 +399,7 @@ class SeqOptimizer(Optimizer, list):
new_sub_profile.append(p[6][idx]) new_sub_profile.append(p[6][idx])
new_opt = SeqOptimizer(*new_l) new_opt = SeqOptimizer(*new_l)
new_callbacks_times = merge_dict(prof1[9], prof2[9])
# We need to assert based on the name as we merge also based on # We need to assert based on the name as we merge also based on
# the name. # the name.
assert set([l.name for l in prof1[0]]).issubset( assert set([l.name for l in prof1[0]]).issubset(
...@@ -384,7 +409,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -384,7 +409,8 @@ class SeqOptimizer(Optimizer, list):
assert len(new_t) == len(new_opt) == len(new_sub_profile) assert len(new_t) == len(new_opt) == len(new_sub_profile)
return (new_opt, new_t, prof1[2] + prof2[2], return (new_opt, new_t, prof1[2] + prof2[2],
prof1[3] + prof2[3], prof1[3] + prof2[3],
-1, -1, new_sub_profile, []) -1, -1, new_sub_profile, [],
new_callbacks_times)
class _metadict: class _metadict:
...@@ -838,7 +864,9 @@ class MergeOptimizer(Optimizer): ...@@ -838,7 +864,9 @@ class MergeOptimizer(Optimizer):
callbacks_time = {} callbacks_time = {}
for k, v in iteritems(fgraph.execute_callbacks_times): for k, v in iteritems(fgraph.execute_callbacks_times):
if k in callbacks_before: if k in callbacks_before:
callbacks_time[k] = v - callbacks_before[k] t = v - callbacks_before[k]
if t > 0:
callbacks_time[k] = t
else: else:
callbacks_time[k] = v callbacks_time[k] = v
else: else:
...@@ -868,7 +896,9 @@ class MergeOptimizer(Optimizer): ...@@ -868,7 +896,9 @@ class MergeOptimizer(Optimizer):
print(blanc, " callbacks_time", file=stream) print(blanc, " callbacks_time", file=stream)
for i in sorted(iteritems(callbacks_time), key=lambda a: a[1]): for i in sorted(iteritems(callbacks_time), key=lambda a: a[1]):
if i[1] > 0: if i[1] > 0:
print(i) # We want to have the __str__ called, so we can't
# just print i.
print(blanc, " ", i[0], ',', i[1], file=stream)
@staticmethod @staticmethod
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
...@@ -1591,10 +1621,14 @@ class PatternSub(LocalOptimizer): ...@@ -1591,10 +1621,14 @@ class PatternSub(LocalOptimizer):
# Use the following classes to apply LocalOptimizers # Use the following classes to apply LocalOptimizers
class Updater: class Updater:
def __init__(self, importer, pruner, chin): def __init__(self, importer, pruner, chin, name=None):
self.importer = importer self.importer = importer
self.pruner = pruner self.pruner = pruner
self.chin = chin self.chin = chin
self.name = name
def __str__(self):
return "Updater{%s}" % str(self.name)
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
if self.importer: if self.importer:
...@@ -1694,7 +1728,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1694,7 +1728,7 @@ class NavigatorOptimizer(Optimizer):
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.failure_callback = failure_callback self.failure_callback = failure_callback
def attach_updater(self, fgraph, importer, pruner, chin=None): def attach_updater(self, fgraph, importer, pruner, chin=None, name=None):
""" """
Install some FunctionGraph listeners to help the navigator deal with Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality. the ignore_trees-related functionality.
...@@ -1709,6 +1743,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -1709,6 +1743,8 @@ class NavigatorOptimizer(Optimizer):
from the graph. from the graph.
chin chin
"on change input" called whenever a node's inputs change. "on change input" called whenever a node's inputs change.
name
name of the Updater to attach.
Returns Returns
------- -------
...@@ -1723,7 +1759,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1723,7 +1759,7 @@ class NavigatorOptimizer(Optimizer):
if importer is None and pruner is None: if importer is None and pruner is None:
return None return None
u = Updater(importer, pruner, chin) u = Updater(importer, pruner, chin, name=name)
fgraph.attach_feature(u) fgraph.attach_feature(u)
return u return u
...@@ -1875,8 +1911,8 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1875,8 +1911,8 @@ class TopoOptimizer(NavigatorOptimizer):
q.remove(node) q.remove(node)
except ValueError: except ValueError:
pass pass
u = self.attach_updater(fgraph, importer, pruner,
u = self.attach_updater(fgraph, importer, pruner) name=getattr(self, 'name', None))
nb = 0 nb = 0
try: try:
t0 = time.time() t0 = time.time()
...@@ -1888,10 +1924,8 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1888,10 +1924,8 @@ class TopoOptimizer(NavigatorOptimizer):
current_node = node current_node = node
nb += self.process_node(fgraph, node) nb += self.process_node(fgraph, node)
loop_t = time.time() - t0 loop_t = time.time() - t0
except Exception: finally:
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
raise
self.detach_updater(fgraph, u)
callback_time = fgraph.execute_callbacks_time - callback_before callback_time = fgraph.execute_callbacks_time - callback_before
nb_nodes_end = len(fgraph.apply_nodes) nb_nodes_end = len(fgraph.apply_nodes)
...@@ -1950,16 +1984,15 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1950,16 +1984,15 @@ class OpKeyOptimizer(NavigatorOptimizer):
q.remove(node) q.remove(node)
except ValueError: except ValueError:
pass pass
u = self.attach_updater(fgraph, importer, pruner) u = self.attach_updater(fgraph, importer, pruner,
name=getattr(self, 'name', None))
try: try:
while q: while q:
node = q.pop() node = q.pop()
current_node = node current_node = node
self.process_node(fgraph, node) self.process_node(fgraph, node)
except Exception: finally:
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
raise
self.detach_updater(fgraph, u)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
""" """
...@@ -1990,6 +2023,9 @@ class ChangeTracker: ...@@ -1990,6 +2023,9 @@ class ChangeTracker:
def on_attach(self, fgraph): def on_attach(self, fgraph):
fgraph.change_tracker = self fgraph.change_tracker = self
def on_detach(self, fgraph):
del fgraph.change_tracker
def merge_dict(d1, d2): def merge_dict(d1, d2):
""" """
...@@ -2033,6 +2069,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2033,6 +2069,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
optimizers, optimizers,
failure_callback=None, failure_callback=None,
ignore_newtrees=True, ignore_newtrees=True,
tracks_on_change_inputs=False,
max_use_ratio=None, max_use_ratio=None,
final_optimizers=None, final_optimizers=None,
cleanup_optimizers=None): cleanup_optimizers=None):
...@@ -2045,6 +2082,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2045,6 +2082,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.global_optimizers = [] self.global_optimizers = []
self.final_optimizers = [] self.final_optimizers = []
self.cleanup_optimizers = [] self.cleanup_optimizers = []
self.tracks_on_change_inputs = tracks_on_change_inputs
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, LocalOptimizer):
...@@ -2191,8 +2229,14 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2191,8 +2229,14 @@ class EquilibriumOptimizer(NavigatorOptimizer):
q.remove(node) q.remove(node)
except ValueError: except ValueError:
pass pass
chin = None
u = self.attach_updater(fgraph, importer, pruner) if self.tracks_on_change_inputs:
def chin(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
q.append(node)
u = self.attach_updater(fgraph, importer, pruner,
chin=chin,
name=getattr(self, 'name', None))
try: try:
while q: while q:
node = q.pop() node = q.pop()
......
...@@ -244,16 +244,26 @@ class EquilibriumDB(DB): ...@@ -244,16 +244,26 @@ class EquilibriumDB(DB):
optimization application. This could result in less fgraph iterations, optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally. but this doesn't mean it will be faster globally.
tracks_on_change_inputs
If True, we will re-apply local opt on nodes whose inputs
changed during local optimization application. This could
result in less fgraph iterations, but this doesn't mean it
will be faster globally.
Notes Notes
----- -----
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both. suppor both.
It is probably not a good idea to have ignore_newtrees=False and
tracks_on_change_inputs=True
""" """
def __init__(self, ignore_newtrees=True): def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
super(EquilibriumDB, self).__init__() super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {} self.__final__ = {}
self.__cleanup__ = {} self.__cleanup__ = {}
...@@ -281,6 +291,7 @@ class EquilibriumDB(DB): ...@@ -281,6 +291,7 @@ class EquilibriumDB(DB):
opts, opts,
max_use_ratio=config.optdb.max_use_ratio, max_use_ratio=config.optdb.max_use_ratio,
ignore_newtrees=self.ignore_newtrees, ignore_newtrees=self.ignore_newtrees,
tracks_on_change_inputs=self.tracks_on_change_inputs,
failure_callback=opt.NavigatorOptimizer.warn_inplace, failure_callback=opt.NavigatorOptimizer.warn_inplace,
final_optimizers=final_opts, final_optimizers=final_opts,
cleanup_optimizers=cleanup_opts) cleanup_optimizers=cleanup_opts)
......
...@@ -1493,7 +1493,7 @@ def local_dnn_convi_output_merge(node, *inputs): ...@@ -1493,7 +1493,7 @@ def local_dnn_convi_output_merge(node, *inputs):
return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@op_lifter([Pool]) @op_lifter([Pool])
def local_pool_dnn_alternative(node, ctx_name): def local_pool_dnn_alternative(node, ctx_name):
if not dnn_available(ctx_name): if not dnn_available(ctx_name):
...@@ -1509,7 +1509,7 @@ def local_pool_dnn_alternative(node, ctx_name): ...@@ -1509,7 +1509,7 @@ def local_pool_dnn_alternative(node, ctx_name):
return dnn_pool(gpu_contiguous(img), ds, stride=stride, pad=pad, mode=mode) return dnn_pool(gpu_contiguous(img), ds, stride=stride, pad=pad, mode=mode)
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@op_lifter([MaxPoolGrad]) @op_lifter([MaxPoolGrad])
def local_pool_dnn_grad_stride(node, ctx_name): def local_pool_dnn_grad_stride(node, ctx_name):
if not dnn_available(ctx_name): if not dnn_available(ctx_name):
...@@ -1533,7 +1533,7 @@ def local_pool_dnn_grad_stride(node, ctx_name): ...@@ -1533,7 +1533,7 @@ def local_pool_dnn_grad_stride(node, ctx_name):
pad) pad)
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@op_lifter([AveragePoolGrad]) @op_lifter([AveragePoolGrad])
def local_avg_pool_dnn_grad_stride(node, ctx_name): def local_avg_pool_dnn_grad_stride(node, ctx_name):
if not dnn_available(ctx_name): if not dnn_available(ctx_name):
...@@ -1556,7 +1556,7 @@ def local_avg_pool_dnn_grad_stride(node, ctx_name): ...@@ -1556,7 +1556,7 @@ def local_avg_pool_dnn_grad_stride(node, ctx_name):
return GpuDnnPoolGrad(mode=mode)(gpu_contiguous(inp), cg, cg, ds, st, pad) return GpuDnnPoolGrad(mode=mode)(gpu_contiguous(inp), cg, cg, ds, st, pad)
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@local_optimizer([GpuSoftmax]) @local_optimizer([GpuSoftmax])
def local_softmax_dnn(node): def local_softmax_dnn(node):
if isinstance(node.op, GpuSoftmax): if isinstance(node.op, GpuSoftmax):
...@@ -1569,7 +1569,7 @@ def local_softmax_dnn(node): ...@@ -1569,7 +1569,7 @@ def local_softmax_dnn(node):
return [out] return [out]
@register_opt('cudnn') @register_opt('cudnn', 'stabilize')
@local_optimizer([GpuElemwise]) @local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
# This looks for GpuDnnSoftmax so we know that we have cudnn. # This looks for GpuDnnSoftmax so we know that we have cudnn.
...@@ -1586,7 +1586,7 @@ def local_log_softmax_dnn(node): ...@@ -1586,7 +1586,7 @@ def local_log_softmax_dnn(node):
return [new_softmax(softmax_node.inputs[0])] return [new_softmax(softmax_node.inputs[0])]
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@op_lifter([LogSoftmax]) @op_lifter([LogSoftmax])
def local_logsoftmax_to_dnn(node, ctx_name): def local_logsoftmax_to_dnn(node, ctx_name):
# Transform the input in the format expected by GpuDnnSoftmax # Transform the input in the format expected by GpuDnnSoftmax
...@@ -1624,7 +1624,7 @@ class NoCuDNNRaise(Optimizer): ...@@ -1624,7 +1624,7 @@ class NoCuDNNRaise(Optimizer):
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn') gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn')
@register_opt('cudnn') @register_opt('cudnn', 'fast_compile')
@op_lifter([SoftmaxGrad]) @op_lifter([SoftmaxGrad])
def local_softmax_dnn_grad(node, ctx_name): def local_softmax_dnn_grad(node, ctx_name):
if not dnn_available(ctx_name): if not dnn_available(ctx_name):
......
...@@ -711,18 +711,14 @@ def local_gpua_careduce(node, context_name): ...@@ -711,18 +711,14 @@ def local_gpua_careduce(node, context_name):
assert reduce_mask[a] == 0 assert reduce_mask[a] == 0
reduce_mask[a] = 1 reduce_mask[a] = 1
shape_of = node.fgraph.shape_feature.shape_of new_in_shp = [shape_i(x, 0)]
x_shape = shape_of[x]
new_in_shp = [x_shape[0]]
new_mask = [reduce_mask[0]] new_mask = [reduce_mask[0]]
for i in xrange(1, x.type.ndim): for i in xrange(1, x.type.ndim):
if reduce_mask[i] == reduce_mask[i - 1]: if reduce_mask[i] == reduce_mask[i - 1]:
new_in_shp[-1] *= x_shape[i] new_in_shp[-1] *= shape_i(x, i)
else: else:
new_mask.append(reduce_mask[i]) new_mask.append(reduce_mask[i])
new_in_shp.append(x_shape[i]) new_in_shp.append(shape_i(x, i))
new_axis = [] new_axis = []
for idx, m in enumerate(new_mask): for idx, m in enumerate(new_mask):
if m == 1: if m == 1:
...@@ -744,8 +740,12 @@ def local_gpua_careduce(node, context_name): ...@@ -744,8 +740,12 @@ def local_gpua_careduce(node, context_name):
greduce(gpu_reshaped_x)) greduce(gpu_reshaped_x))
if reduce_reshaped_x.ndim != node.outputs[0].ndim: if reduce_reshaped_x.ndim != node.outputs[0].ndim:
out_shp = []
for i in range(x.ndim):
if i not in node.op.axis:
out_shp.append(shape_i(x, i))
unreshaped_reduce = reduce_reshaped_x.reshape( unreshaped_reduce = reduce_reshaped_x.reshape(
tensor.stack(shape_of[node.outputs[0]])) tensor.stack(out_shp))
else: else:
unreshaped_reduce = reduce_reshaped_x unreshaped_reduce = reduce_reshaped_x
return [unreshaped_reduce] return [unreshaped_reduce]
......
...@@ -249,6 +249,7 @@ if __name__ == "__main__": ...@@ -249,6 +249,7 @@ if __name__ == "__main__":
cuda version 7.5 7.0 6.5 cuda version 7.5 7.0 6.5
gpu gpu
M40 0.47s
k80 0.96s k80 0.96s
K6000/NOECC 0.69s K6000/NOECC 0.69s
K40 0.88s K40 0.88s
......
...@@ -2526,7 +2526,8 @@ if True: ...@@ -2526,7 +2526,8 @@ if True:
out = as_cuda_ndarray_variable(out.dimshuffle(0, 1)) out = as_cuda_ndarray_variable(out.dimshuffle(0, 1))
return [out] return [out]
@register_opt('cudnn') @register_opt('cudnn', 'stabilize', 'fast_compile')
# We put fast_compile as otherwise it won't be on the GPU.
@local_optimizer([GpuElemwise, LogSoftmax]) @local_optimizer([GpuElemwise, LogSoftmax])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
# The log-softmax implementation is only available starting at cuDNN V3 # The log-softmax implementation is only available starting at cuDNN V3
......
...@@ -14,6 +14,7 @@ from . import dnn ...@@ -14,6 +14,7 @@ from . import dnn
import theano import theano
from theano import scalar as scal from theano import scalar as scal
from theano import config, tensor, gof from theano import config, tensor, gof
from theano.compile.ops import shape_i
import theano.ifelse import theano.ifelse
import theano.tensor.signal.pool import theano.tensor.signal.pool
import theano.tensor.nnet import theano.tensor.nnet
...@@ -900,18 +901,14 @@ def local_gpu_careduce(node): ...@@ -900,18 +901,14 @@ def local_gpu_careduce(node):
# to make them a single dimension, do the reduction, and # to make them a single dimension, do the reduction, and
# then reshape to get them back. # then reshape to get them back.
shape_of = node.fgraph.shape_feature.shape_of new_in_shp = [shape_i(x, 0)]
x_shape = shape_of[x]
new_in_shp = [x_shape[0]]
new_mask = [reduce_mask[0]] new_mask = [reduce_mask[0]]
for i in xrange(1, x.type.ndim): for i in xrange(1, x.type.ndim):
if reduce_mask[i] == reduce_mask[i - 1]: if reduce_mask[i] == reduce_mask[i - 1]:
new_in_shp[-1] *= x_shape[i] new_in_shp[-1] *= shape_i(x, i)
else: else:
new_mask.append(reduce_mask[i]) new_mask.append(reduce_mask[i])
new_in_shp.append(x_shape[i]) new_in_shp.append(shape_i(x, i))
new_greduce = GpuCAReduce(new_mask, scalar_op) new_greduce = GpuCAReduce(new_mask, scalar_op)
new_x = x.reshape(tensor.stack(new_in_shp)) new_x = x.reshape(tensor.stack(new_in_shp))
...@@ -936,8 +933,11 @@ def local_gpu_careduce(node): ...@@ -936,8 +933,11 @@ def local_gpu_careduce(node):
# Restore the expected shape of the output # Restore the expected shape of the output
if rval.ndim != out.ndim: if rval.ndim != out.ndim:
rval = rval.reshape( out_shp = []
tensor.stack(shape_of[out])) for i in range(x.ndim):
if i not in node.op.axis:
out_shp.append(shape_i(x, i))
rval = rval.reshape(tensor.stack(out_shp))
if rval.type == out.type: if rval.type == out.type:
return [rval] return [rval]
......
...@@ -1436,7 +1436,8 @@ class GemmOptimizer(Optimizer): ...@@ -1436,7 +1436,8 @@ class GemmOptimizer(Optimizer):
if new_node is not node: if new_node is not node:
nodelist.append(new_node) nodelist.append(new_node)
u = theano.gof.opt.Updater(on_import, None, None) u = theano.gof.opt.Updater(on_import, None, None,
name="GemmOptimizer")
fgraph.attach_feature(u) fgraph.attach_feature(u)
while did_something: while did_something:
nb_iter += 1 nb_iter += 1
......
...@@ -1260,6 +1260,12 @@ class ShapeFeature(object): ...@@ -1260,6 +1260,12 @@ class ShapeFeature(object):
for node in fgraph.toposort(): for node in fgraph.toposort():
self.on_import(fgraph, node, reason='on_attach') self.on_import(fgraph, node, reason='on_attach')
def on_detach(self, fgraph):
self.shape_of = {}
self.scheduled = {}
self.shape_of_reverse_index = {}
del fgraph.shape_feature
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
if node.outputs[0] in self.shape_of: if node.outputs[0] in self.shape_of:
# this is a revert, not really an import # this is a revert, not really an import
...@@ -1436,10 +1442,23 @@ class ShapeOptimizer(Optimizer): ...@@ -1436,10 +1442,23 @@ class ShapeOptimizer(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
pass pass
class UnShapeOptimizer(Optimizer):
"""Optimizer remove ShapeFeature as an fgraph feature."""
def apply(self, fgraph):
for feature in fgraph._features:
if isinstance(feature, ShapeFeature):
fgraph.remove_feature(feature)
# Register it after merge1 optimization at 0. We don't want to track # Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node. # the shape of merged node.
theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(), theano.compile.mode.optdb.register('ShapeOpt', ShapeOptimizer(),
0.1, 'fast_run', 'fast_compile') 0.1, 'fast_run', 'fast_compile')
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seem resonable.
theano.compile.mode.optdb.register('UnShapeOpt', UnShapeOptimizer(),
10)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论