提交 4b3fce0a authored 作者: Frederic Bastien's avatar Frederic Bastien

Rename an optimization to make its name clearer and more consistent.

上级 6b24aefb
...@@ -141,7 +141,7 @@ def local_gpu_elemwise_0(node): ...@@ -141,7 +141,7 @@ def local_gpu_elemwise_0(node):
if isinstance(node.op, tensor.Elemwise) and dtype_in_elemwise_supported(node.op): if isinstance(node.op, tensor.Elemwise) and dtype_in_elemwise_supported(node.op):
if numpy.any([i.owner and isinstance(i.owner.op, HostFromGpu) for i in node.inputs]): if numpy.any([i.owner and isinstance(i.owner.op, HostFromGpu) for i in node.inputs]):
if numpy.all([o.type.dtype == 'float32' for o in node.outputs]): if numpy.all([o.type.dtype == 'float32' for o in node.outputs]):
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later #don't set any inplace pattern. gpu_inplace_optimizer will do it later
new_op = GpuElemwise(node.op.scalar_op) new_op = GpuElemwise(node.op.scalar_op)
# first establish that float32 can store all inputs # first establish that float32 can store all inputs
...@@ -184,7 +184,7 @@ def local_gpu_elemwise_1(node): ...@@ -184,7 +184,7 @@ def local_gpu_elemwise_1(node):
dtype_in_elemwise_supported(node.op)): dtype_in_elemwise_supported(node.op)):
elemwise_node = host_i.owner elemwise_node = host_i.owner
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later #don't set any inplace pattern. gpu_inplace_optimizer will do it later
new_op = GpuElemwise(elemwise_node.op.scalar_op) new_op = GpuElemwise(elemwise_node.op.scalar_op)
if all([i.dtype=='float32' for i in elemwise_node.inputs]): if all([i.dtype=='float32' for i in elemwise_node.inputs]):
gpu_elemwise = new_op(*[gpu_from_host(i) for i in elemwise_node.inputs]) gpu_elemwise = new_op(*[gpu_from_host(i) for i in elemwise_node.inputs])
...@@ -993,10 +993,10 @@ else: ...@@ -993,10 +993,10 @@ else:
compile.optdb.register('gpu_elemwise_fusion', tensor.opt.FusionOptimizer(gpu_local_elemwise_fusion), 71.00, 'fusion', 'local_elemwise_fusion') compile.optdb.register('gpu_elemwise_fusion', tensor.opt.FusionOptimizer(gpu_local_elemwise_fusion), 71.00, 'fusion', 'local_elemwise_fusion')
#GpuElemwise inplace #GpuElemwise inplace
gpu_insert_inplace_optimizer = tensor.opt.insert_inplace_optimizer_op( gpu_inplace_elemwise_optimizer = tensor.opt.inplace_elemwise_optimizer_op(
GpuElemwise) GpuElemwise)
compile.optdb.register('gpu_inplace_opt', gpu_insert_inplace_optimizer, 75, optdb.register('gpu_inplace_elemwise_opt', gpu_inplace_elemwise_optimizer, 75,
'fast_run', 'inplace','gpu_inplace') 'fast_run', 'inplace','gpu_inplace', 'gpu')
@register_opt() @register_opt()
@local_optimizer([tensor.Alloc]) @local_optimizer([tensor.Alloc])
......
...@@ -107,14 +107,14 @@ theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb', ...@@ -107,14 +107,14 @@ theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb',
theano.configparser.IntParam(-1), theano.configparser.IntParam(-1),
in_c_key=False) in_c_key=False)
def insert_inplace_optimizer_op(OP): def inplace_elemwise_optimizer_op(OP):
""" """
We parametrise it to make it work for Elemwise and GpuElemwise op. We parametrise it to make it work for Elemwise and GpuElemwise op.
""" """
@gof.optimizer @gof.optimizer
def insert_inplace_optimizer(env): def inplace_elemwise_optimizer(env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_elemwise_optimizer.optimize(env)
Attempts to replace all Broadcast ops by versions of them Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast that operate inplace. It operates greedily: for each Broadcast
...@@ -193,7 +193,7 @@ def insert_inplace_optimizer_op(OP): ...@@ -193,7 +193,7 @@ def insert_inplace_optimizer_op(OP):
for r,new_r in zip(node.outputs,new.outputs): for r,new_r in zip(node.outputs,new.outputs):
env.replace(r,new_r, env.replace(r,new_r,
reason="insert_inplace_optimizer") reason="inplace_elemwise_optimizer")
nb_change_no_validate +=1 nb_change_no_validate +=1
if nb_change_no_validate >= validate_each_change: if nb_change_no_validate >= validate_each_change:
env.validate() env.validate()
...@@ -218,11 +218,11 @@ def insert_inplace_optimizer_op(OP): ...@@ -218,11 +218,11 @@ def insert_inplace_optimizer_op(OP):
if not raised_warning: if not raised_warning:
print >> sys.stderr, "Their was some inplace optimization that was not done due to unexpected error" print >> sys.stderr, "Their was some inplace optimization that was not done due to unexpected error"
env.revert(chk) env.revert(chk)
return insert_inplace_optimizer return inplace_elemwise_optimizer
insert_inplace_optimizer = insert_inplace_optimizer_op(T.Elemwise) inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)
compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run', 'inplace') compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论