提交 f5df680d authored 作者: Frederic's avatar Frederic

pep8

上级 76426562
...@@ -22,9 +22,6 @@ Also, we should make the fgraph refuse optimization that break the canonization ...@@ -22,9 +22,6 @@ Also, we should make the fgraph refuse optimization that break the canonization
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
import logging import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
...@@ -35,10 +32,12 @@ from theano.tensor import basic as T ...@@ -35,10 +32,12 @@ from theano.tensor import basic as T
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox from theano.gof import InconsistencyError, toolbox
from theano.tensor.basic import get_scalar_constant_value, NotScalarConstantError from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
class MaxAndArgmaxOptimizer(Optimizer): class MaxAndArgmaxOptimizer(Optimizer):
"""Replace MaxAndArgmax by CAReduce when the argmax is not used """Replace MaxAndArgmax by CAReduce when the argmax is not used
...@@ -56,23 +55,25 @@ class MaxAndArgmaxOptimizer(Optimizer): ...@@ -56,23 +55,25 @@ class MaxAndArgmaxOptimizer(Optimizer):
did_something = False did_something = False
for node in nodelist: for node in nodelist:
if node.op == T._max_and_argmax: if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0: if len(node.outputs[1].clients) == 0:
try: try:
axis=get_scalar_constant_value(node.inputs[1]) axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError: except NotScalarConstantError:
return False return False
new = CAReduce(scal.maximum,axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
try: try:
fgraph.replace_all_validate( fgraph.replace_all_validate(
((node.outputs[0],new),), ((node.outputs[0], new),),
reason = self.__class__.__name__) reason=self.__class__.__name__)
did_something = True did_something = True
break break
except InconsistencyError, e: except InconsistencyError, e:
pass pass
register_uncanonicalize(MaxAndArgmaxOptimizer(),name='MaxAndArgmaxOptimizer') register_uncanonicalize(MaxAndArgmaxOptimizer(),
name='MaxAndArgmaxOptimizer')
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T._shape]) @gof.local_optimizer([T._shape])
...@@ -87,9 +88,12 @@ def local_max_to_min(node): ...@@ -87,9 +88,12 @@ def local_max_to_min(node):
""" """
if node.op == T.neg and node.inputs[0].owner: if node.op == T.neg and node.inputs[0].owner:
max = node.inputs[0] max = node.inputs[0]
if max.owner and isinstance(max.owner.op, CAReduce) and max.owner.op.scalar_op==scal.maximum: if (max.owner and
isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == scal.maximum):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum,max.owner.op.axis)(neg.owner.inputs[0])] return [CAReduce(scal.minimum,
max.owner.op.axis)(neg.owner.inputs[0])]
return False return False
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论