提交 f5df680d authored 作者: Frederic's avatar Frederic

pep8

上级 76426562
......@@ -22,9 +22,6 @@ Also, we should make the fgraph refuse optimization that break the canonization
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging
_logger = logging.getLogger('theano.tensor.opt')
......@@ -35,10 +32,12 @@ from theano.tensor import basic as T
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from theano.tensor.basic import get_scalar_constant_value, NotScalarConstantError
from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
class MaxAndArgmaxOptimizer(Optimizer):
"""Replace MaxAndArgmax by CAReduce when the argmax is not used
......@@ -56,23 +55,25 @@ class MaxAndArgmaxOptimizer(Optimizer):
did_something = False
for node in nodelist:
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0:
if len(node.outputs[1].clients) == 0:
try:
axis=get_scalar_constant_value(node.inputs[1])
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum,axis)(node.inputs[0])
new = CAReduce(scal.maximum, axis)(node.inputs[0])
try:
fgraph.replace_all_validate(
((node.outputs[0],new),),
reason = self.__class__.__name__)
((node.outputs[0], new),),
reason=self.__class__.__name__)
did_something = True
break
except InconsistencyError, e:
pass
register_uncanonicalize(MaxAndArgmaxOptimizer(),name='MaxAndArgmaxOptimizer')
register_uncanonicalize(MaxAndArgmaxOptimizer(),
name='MaxAndArgmaxOptimizer')
@register_uncanonicalize
@gof.local_optimizer([T._shape])
......@@ -87,9 +88,12 @@ def local_max_to_min(node):
"""
if node.op == T.neg and node.inputs[0].owner:
max = node.inputs[0]
if max.owner and isinstance(max.owner.op, CAReduce) and max.owner.op.scalar_op==scal.maximum:
if (max.owner and
isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == scal.maximum):
neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum,max.owner.op.axis)(neg.owner.inputs[0])]
return [CAReduce(scal.minimum,
max.owner.op.axis)(neg.owner.inputs[0])]
return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论