Commit d40861ec, authored by Iban Harlouchet

Flake8 for theano/tensor/opt.py

Parent 34b98041
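This is a pure style cleanup: it brings theano/tensor/opt.py into flake8 compliance and then drops the file from the project's flake8 whitelist (the second file in the diff below). The recurring rewrites are mechanical: the module-level _logger assignment moves below the imports so imports stay at the top of the file (E402), backslash line continuations are replaced by parenthesized expressions with and/or moved to the ends of lines, block comments gain a space after the # (E265), over-long lines are wrapped (E501), presumably unused imports (AdvancedSubtensor, inc_subtensor) are dropped (F401), and a == True comparison is simplified (E712). A minimal sketch of the dominant pattern, with illustrative names rather than code copied from the diff:

def is_movable_old(i, o_type):
    # Old style: backslash continuations, boolean operators at line starts.
    return i.owner \
        and i.type == o_type \
        and len(i.clients) == 1


def is_movable_new(i, o_type):
    # New style: one parenthesized expression with 'and' at line ends,
    # which is the form this commit applies throughout the file.
    return (i.owner and
            i.type == o_type and
            len(i.clients) == 1)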
--- a/theano/tensor/opt.py
+++ b/theano/tensor/opt.py
@@ -6,8 +6,6 @@ from __future__ import print_function
 # TODO: 0*x -> 0
 import logging
-_logger = logging.getLogger('theano.tensor.opt')
-
 import itertools
 import operator
 import sys
@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
                                      Subtensor, IncSubtensor, make_constant,
                                      AdvancedIncSubtensor1,
                                      AdvancedIncSubtensor,
-                                     AdvancedSubtensor,
                                      AdvancedSubtensor1,
                                      advanced_subtensor,
                                      advanced_subtensor1,
-                                     advanced_inc_subtensor1,
-                                     inc_subtensor)
+                                     advanced_inc_subtensor1)
 from theano import scalar
 from theano.scalar import basic
 from theano.tensor import basic as T
@@ -56,6 +52,8 @@ from theano.gof import toolbox
 from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
 from six import StringIO
 
+_logger = logging.getLogger('theano.tensor.opt')
+
 theano.configparser.AddConfigVar('on_shape_error',
                                  "warn: print a warning and use the default"
                                  " value. raise: raise an error",
@@ -165,8 +163,8 @@ def broadcast_like(value, template, fgraph, dtype=None):
     # the template may have 1s in its shape without being broadcastable
     if rval.broadcastable != template.broadcastable:
         rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
-                                     if rval.broadcastable[i]
-                                     and not template.broadcastable[i]])
+                                     if rval.broadcastable[i] and
+                                     not template.broadcastable[i]])
     assert rval.type.dtype == dtype
 
     if rval.type.broadcastable != template.broadcastable:
@@ -178,7 +176,8 @@ def broadcast_like(value, template, fgraph, dtype=None):
     return rval
 
 
-theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb',
+theano.configparser.AddConfigVar(
+    'tensor.insert_inplace_optimizer_validate_nb',
     "-1: auto, if graph have less then 500 nodes 1, else 10",
     theano.configparser.IntParam(-1),
     in_c_key=False)
@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP):
             # target.
             # Remove here as faster.
             candidate_inputs = [i for i in xrange(len(node.inputs))
-                                if i not in baseline.values() \
-                                and not isinstance(node.inputs[i],
-                                                   Constant)\
-                                and not fgraph.destroyers(node.inputs[i])\
-                                and node.inputs[i] not in protected_inputs]
+                                if i not in baseline.values() and
+                                not isinstance(node.inputs[i], Constant) and
+                                not fgraph.destroyers(node.inputs[i]) and
+                                node.inputs[i] not in protected_inputs]
 
             verbose = False
@@ -274,12 +272,12 @@ def inplace_elemwise_optimizer_op(OP):
                 if hasattr(op.scalar_op, "make_new_inplace"):
                     new_scal = op.scalar_op.make_new_inplace(
                         scalar.transfer_type(
-                            *[inplace_pattern.get(i, None) \
+                            *[inplace_pattern.get(i, None)
                               for i in xrange(len(node.outputs))]))
                 else:
                     new_scal = op.scalar_op.__class__(
                         scalar.transfer_type(
-                            *[inplace_pattern.get(i, None) \
+                            *[inplace_pattern.get(i, None)
                               for i in xrange(len(node.outputs))]))
                 new_outputs = OP(new_scal, inplace_pattern)(
                     *node.inputs, **dict(return_list=True))
@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP):
                         nb_change_no_validate = 0
                 except (ValueError, TypeError, InconsistencyError) as e:
                     if check_each_change != 1 and not raised_warning:
-                        print((
-                            "Some inplace optimization was not "
-                            "performed due to unexpected error:"), file=sys.stderr)
+                        print(("Some inplace optimization was not "
+                               "performed due to unexpected error:"),
+                              file=sys.stderr)
                         print(e, file=sys.stderr)
                         raised_warning = True
                     fgraph.revert(chk)
@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP):
             except Exception:
                 if not raised_warning:
                     print(("Some inplace optimization was not "
-                           "performed due to unexpected error"), file=sys.stderr)
+                           "performed due to unexpected error"),
+                          file=sys.stderr)
                 fgraph.revert(chk)
     return inplace_elemwise_optimizer
@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs):
 # Register merge_optimizer as a global opt during canonicalize
-compile.optdb['canonicalize'].register(
-    'canon_merge', merge_optimizer, 'fast_run', final_opt=True)
+compile.optdb['canonicalize'].register('canon_merge', merge_optimizer,
+                                       'fast_run', final_opt=True)
 
 #####################
@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node):
     inplace. The newly-introduced transpositions are not inplace, this will
     be taken care of in a later optimization phase.
     """
-    if not (isinstance(node.op, T.DimShuffle)
-            and node.op.new_order == (1, 0)):
+    if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)):
         return False
-    if not (node.inputs[0].owner
-            and isinstance(node.inputs[0].owner.op, T.Dot)):
+    if not (node.inputs[0].owner and
+            isinstance(node.inputs[0].owner.op, T.Dot)):
         return False
     x, y = node.inputs[0].owner.inputs
@@ -601,10 +599,9 @@ class MakeVector(T.Op):
     def make_node(self, *inputs):
         inputs = list(map(T.as_tensor_variable, inputs))
-        if not all(a.type == inputs[0].type for a in inputs) or (
-                len(inputs) > 0 and inputs[0].dtype != self.dtype):
-            dtype = theano.scalar.upcast(self.dtype,
-                                         *[i.dtype for i in inputs])
+        if (not all(a.type == inputs[0].type for a in inputs) or
+                (len(inputs) > 0 and inputs[0].dtype != self.dtype)):
+            dtype = theano.scalar.upcast(self.dtype, *[i.dtype for i in inputs])
             # upcast the input to the determined dtype,
             # but don't downcast anything
             assert dtype == self.dtype, (
@@ -613,10 +610,8 @@ class MakeVector(T.Op):
             if not all(self.dtype == T.cast(i, dtype=dtype).dtype
                        for i in inputs):
                 raise TypeError("MakeVector.make_node expected inputs"
-                                " upcastable to %s. got %s" % (
-                                    self.dtype,
-                                    str([i.dtype for i in inputs])
-                                ))
+                                " upcastable to %s. got %s" %
+                                (self.dtype, str([i.dtype for i in inputs])))
             inputs = [T.cast(i, dtype=dtype) for i in inputs]
         assert all(self.dtype == a.dtype for a in inputs)
         assert all(a.ndim == 0 for a in inputs)
@@ -625,11 +620,9 @@ class MakeVector(T.Op):
             dtype = inputs[0].type.dtype
         else:
             dtype = self.dtype
-        #bcastable = (len(inputs) == 1)
+        # bcastable = (len(inputs) == 1)
         bcastable = False
-        otype = T.TensorType(
-            broadcastable=(bcastable,),
-            dtype=dtype)
+        otype = T.TensorType(broadcastable=(bcastable,), dtype=dtype)
         return T.Apply(self, inputs, [otype()])
 
     def __str__(self):
@@ -700,13 +693,14 @@ class MakeVectorPrinter:
         if r.owner is None:
             raise TypeError("Can only print make_vector.")
         elif isinstance(r.owner.op, MakeVector):
-            return "[%s]" % ", ".join(pstate.pprinter.process(
-                input, pstate.clone(precedence=1000)) for input
-                in r.owner.inputs)
+            return "[%s]" % ", ".join(
+                pstate.pprinter.process(input, pstate.clone(precedence=1000))
+                for input in r.owner.inputs)
         else:
             raise TypeError("Can only print make_vector.")
 
-T.pprint.assign(lambda pstate, r: r.owner and isinstance(
-    r.owner.op, MakeVector), MakeVectorPrinter())
+T.pprint.assign(lambda pstate, r: r.owner and
+                isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
 
 
 class ShapeFeature(object):
@@ -957,10 +951,10 @@ class ShapeFeature(object):
         # Merge other_shape with r_shape, giving the priority to other_shape
         merged_shape = []
         for i, ps in enumerate(other_shape):
-            if (ps.owner
-                    and isinstance(getattr(ps.owner, 'op', None), Shape_i)
-                    and ps.owner.op.i == i
-                    and ps.owner.inputs[0] in (r, other_r)):
+            if (ps.owner and
+                    isinstance(getattr(ps.owner, 'op', None), Shape_i) and
+                    ps.owner.op.i == i and
+                    ps.owner.inputs[0] in (r, other_r)):
                 # If other_shape[i] is uninformative, use r_shape[i].
                 # For now, we consider 2 cases of uninformative other_shape[i]:
                 # - Shape_i(i)(other_r);
@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node):
         return
     # TODO: cut out un-necessary dimshuffles of v
 
-    assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type,
-                                                  'orig', node.outputs[0].type,
-                                                  'node', node,
-                                                  )  # theano.printing.debugprint(node.outputs[0], file='str'))
+    assert rval[0].type == node.outputs[0].type, (
+        'rval', rval[0].type, 'orig', node.outputs[0].type, 'node',
+        node,)  # theano.printing.debugprint(node.outputs[0], file='str'))
     return rval
@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node):
     try:
         idx, = node.op.idx_list
     except Exception:
-        #'how can you have multiple indexes into a shape?'
+        # 'how can you have multiple indexes into a shape?'
         raise
 
     if isinstance(idx, (scalar.Scalar, T.TensorType)):
@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node):
         elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
             return [node.inputs[0]]
-        elif (node.op.scalar_op == theano.scalar.identity
-              and len(node.inputs) == 1):
+        elif (node.op.scalar_op == theano.scalar.identity and
+              len(node.inputs) == 1):
             return [node.inputs[0]]
@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
        # At least one input must have an owner that is either a AllocOP or a
        # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
        # nothing to optimize.
-        if not any([i.owner
-                    and (isinstance(i.owner.op, AllocOP) or
-                         dimshuffled_alloc(i))
-                    for i in node.inputs]):
+        if not any([i.owner and (isinstance(i.owner.op, AllocOP) or
+                                 dimshuffled_alloc(i)) for i in node.inputs]):
             return False
 
        # Search for input that we can use as a baseline for the dimensions.
@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
             if i.type.broadcastable == node.outputs[0].type.broadcastable:
                 # Prefer an input that is not a AllocOP nor a DimShuffleOP of a
                 # AllocOP so that all allocs can be optimized.
-                if not (i.owner
-                        and (isinstance(i.owner.op, AllocOP)
-                             or dimshuffled_alloc(i))):
+                if not (i.owner and (isinstance(i.owner.op, AllocOP) or
+                                     dimshuffled_alloc(i))):
                     assert_op_idx = idx
                     break
@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
             # there is more than one then do all but one. number of
             # inputs with alloc or dimshuffle alloc
             l2 = [i for i in node.inputs
-                  if (i.owner and (isinstance(i.owner.op, AllocOP)
-                      or dimshuffled_alloc(i)))]
+                  if (i.owner and (isinstance(i.owner.op, AllocOP) or
+                      dimshuffled_alloc(i)))]
             # If only 1 alloc or dimshuffle alloc, it is the one we
             # will use for the shape. So no alloc would be removed.
             if len(l2) > 1:
@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
         same_shape = node.fgraph.shape_feature.same_shape
         for i in node.inputs:
             # Remove alloc
-            if (i.owner and isinstance(i.owner.op, AllocOP)
-                and i.owner.inputs[0].type != i.owner.outputs[0].type):
+            if (i.owner and isinstance(i.owner.op, AllocOP) and
+                    i.owner.inputs[0].type != i.owner.outputs[0].type):
                 # when i.owner.inputs[0].type == i.owner.outputs[0].type we
                 # will remove that alloc later
                 assert i.type.ndim == cmp_op.ndim
-                if (theano.config.experimental.local_alloc_elemwise_assert
-                    and not same_shape(i, cmp_op)):
+                if (theano.config.experimental.local_alloc_elemwise_assert and
+                        not same_shape(i, cmp_op)):
                     assert_op = assert_(assert_op,
                                         *[T.eq(i.shape[idx], cmp_op.shape[idx])
                                           for idx in xrange(i.type.ndim)
@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node):
                                           i.ndim))
                 else:
                     if shape_i is None:
                         return
-                    new_inputs.append(T.alloc(T.cast(cval_i,
-                                                     output_dtype),
-                                              *[shape_i(d)(i) for d in xrange(i.ndim)]))
-                    #print >> sys.stderr, "AAA",
-                    #*[Shape_i(d)(i) for d in xrange(i.ndim)]
+                    new_inputs.append(
+                        T.alloc(T.cast(cval_i, output_dtype),
+                                *[shape_i(d)(i)
+                                  for d in xrange(i.ndim)]))
+                    # print >> sys.stderr, "AAA",
+                    # *[Shape_i(d)(i) for d in xrange(i.ndim)]
             except NotScalarConstantError:
                 # for the case of a non-scalar
                 if isinstance(i, T.TensorConstant):
@@ -1994,7 +1983,7 @@ def local_set_to_inc_subtensor(node):
     AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
     """
     if (isinstance(node.op, AdvancedIncSubtensor1) and
-        node.op.set_instead_of_inc == True and
+        node.op.set_instead_of_inc and
         node.inputs[1].owner and
         isinstance(node.inputs[1].owner.op, Elemwise) and
         isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
@@ -2030,8 +2019,8 @@ def local_useless_slice(node):
         last_slice = len(slices)
         for s in slices[::-1]:
             # check if slice and then check slice indices
-            if (isinstance(s, slice) and s.start is None and s.stop is None
-                and (s.step is None or T.extract_constant(s.step) == 1)):
+            if (isinstance(s, slice) and s.start is None and s.stop is None and
+                    (s.step is None or T.extract_constant(s.step) == 1)):
                 last_slice -= 1
             else:
                 break
@@ -2101,8 +2090,7 @@ def local_useless_subtensor(node):
                                  T.ScalarFromTensor)):
                     length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
                 elif (length_pos.owner and
-                      isinstance(length_pos.owner.op,
-                                 T.TensorFromScalar)):
+                      isinstance(length_pos.owner.op, T.TensorFromScalar)):
                     length_pos = length_pos.owner.inputs[0]
                 else:
                     # We did not find underlying variables of the same type
@@ -2346,8 +2334,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
         stop = T.switch(T.lt(reverse2 * reverse1, 0),
                         T.switch(T.lt(reverse1, 0), np_stop, pn_stop),
-                        T.switch(T.lt(reverse1, 0), nn_stop, pp_stop
-                                 ))
+                        T.switch(T.lt(reverse1, 0), nn_stop, pp_stop))
         step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step)
         start = T.switch(T.le(flen, 0), 0, start)
@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node):
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
     if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
-        b_indices = b_indices[:b.ndim-2] + (slice(None, None, None),) + b_indices[b.ndim-2:]
+        b_indices = (b_indices[:b.ndim - 2] +
+                     (slice(None, None, None),) + b_indices[b.ndim - 2:])
 
     a_sub = a.__getitem__(tuple(a_indices))
     b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node):
     """
     def movable(i):
         # Return True iff this is a incsubtensor that we can move
-        return i.owner \
-            and isinstance(i.owner.op, (IncSubtensor,
-                                        AdvancedIncSubtensor1,
-                                        AdvancedIncSubtensor,
-                                        )) \
-            and i.type == o_type \
-            and len(i.clients) == 1 \
-            and not i.owner.op.set_instead_of_inc
+        return (i.owner and
+                isinstance(i.owner.op, (IncSubtensor,
+                                        AdvancedIncSubtensor1,
+                                        AdvancedIncSubtensor,)) and
+                i.type == o_type and
+                len(i.clients) == 1 and
+                not i.owner.op.set_instead_of_inc)
 
     if node.op == T.add:
         o_type = node.outputs[0].type
@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node):
         movable_inputs = [i for i in node.inputs if movable(i)]
 
         if movable_inputs:
-            new_inputs = [i for i in node.inputs if not movable(i)] \
-                + [mi.owner.inputs[0] for mi in movable_inputs]
+            new_inputs = ([i for i in node.inputs if not movable(i)] +
+                          [mi.owner.inputs[0] for mi in movable_inputs])
             new_add = T.add(*new_inputs)
 
             # stack up the new incsubtensors
@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node):
         return [new_node]
     return False
 compile.optdb.register('local_inplace_setsubtensor',
-                       TopoOptimizer(local_inplace_setsubtensor,
-                                     failure_callback=TopoOptimizer.warn_inplace), 60,
-                       'fast_run', 'inplace')  # DEBUG
+                       TopoOptimizer(
+                           local_inplace_setsubtensor,
+                           failure_callback=TopoOptimizer.warn_inplace),
+                       60, 'fast_run', 'inplace')  # DEBUG
 
 
 @gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
@@ -2749,11 +2737,11 @@ def local_adv_sub1_adv_inc_sub1(node):
     if (not inp.owner.op.set_instead_of_inc and
             T.extract_constant(x) != 0):
         return
-    cond = [T.all(T.and_(T.lt(idx, x.shape[0]),
-                         T.ge(idx, -x.shape[0])))]
+    cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
     if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
         cond.append(T.eq(idx.shape[0], y.shape[0]))
-    y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 that was optimized away")(y, *cond)
+    y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 "
+               "that was optimized away")(y, *cond)
     if y.dtype == node.outputs[0].dtype:
         return [y]
@@ -2828,7 +2816,8 @@ def local_useless_inc_subtensor_alloc(node):
         # Build `z_broad` explicitly to include extra implicit dimensions.
         z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable)
 
-        cond = [# The shapes of `y` and `xi` must either agree or `y` may
+        cond = [
+            # The shapes of `y` and `xi` must either agree or `y` may
             # also have shape equal to 1 which may be treated as a
             # broadcastable dimension by the subtensor op.
             T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k]))
@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer):
                 # the num/denum of its input
                 dsn = input.owner    # dimshuffle node
                 dsop = dsn.op        # dimshuffle op
-                dsi0 = dsn.inputs[0]  # the first input of the
-                                      # dimshuffle i.e. the ndarray to
-                                      # redim
+
+                # the first input of the dimshuffle i.e. the ndarray to redim
+                dsi0 = dsn.inputs[0]
 
                 # The compatible order is a DimShuffle "new_order" of the form:
                 # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer):
                 # different numbers of dimensions (hence why we can
                 # discard its information - we know we can retrieve it
                 # later on).
-                compatible_order = ('x',) * (input.type.ndim
-                                             - dsi0.type.ndim) + tuple(
-                    range(dsi0.type.ndim))
+                compatible_order = (('x',) *
+                                    (input.type.ndim - dsi0.type.ndim) +
+                                    tuple(range(dsi0.type.ndim)))
                 if dsop.new_order == compatible_order:
                     # If the "new_order" is the one we recognize,
                     # we return the num_denum of the dimshuffled input.
@@ -3815,7 +3804,7 @@ class Canonizer(gof.LocalOptimizer):
         new = self.merge_num_denum(num, denum)
         if new.type.dtype != out.type.dtype:
-            #new = T.fill(out, new)
+            # new = T.fill(out, new)
             elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(
                 getattr(scalar, out.type.dtype))))
             new = elem_op(new)
@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node):
     """
     Elemwise{sub}(X,X) -> zeros_like(X)
     """
-    if (isinstance(node.op, T.Elemwise)
-            and node.op.scalar_op.nin == 2
-            and node.op.scalar_op == scalar.sub
-            and node.inputs[0] == node.inputs[1]):
+    if (isinstance(node.op, T.Elemwise) and
+            node.op.scalar_op.nin == 2 and
+            node.op.scalar_op == scalar.sub and
+            node.inputs[0] == node.inputs[1]):
         return [T.zeros_like(node.inputs[0])]
@@ -4014,8 +4003,7 @@ def local_sum_div_dimshuffle(node):
                     new_denom = T.DimShuffle(
                         thing_dimshuffled.type.broadcastable,
-                        new_new_order
-                    )(thing_dimshuffled)
+                        new_new_order)(thing_dimshuffled)
                     return [T.true_div(node.op(numerator), new_denom)]
                 # else:
                 #     print 'incompatible dims:', axis, new_order
@@ -4052,8 +4040,9 @@ def local_op_of_op(node):
         # We manipulate the graph so this is done to make sure the opt
         # doesn't affect other computations.
         if len(node_inps.clients) == 1:
-            if (node_inps.owner and (isinstance(node_inps.owner.op, T.elemwise.Prod)
-                    or isinstance(node_inps.owner.op, T.elemwise.Sum))):
+            if (node_inps.owner and
+                    (isinstance(node_inps.owner.op, T.elemwise.Prod) or
+                     isinstance(node_inps.owner.op, T.elemwise.Sum))):
                 # check to see either the inner or outer prod is doing a
                 # product over all axis, in which case we can remove it
@@ -4074,7 +4063,6 @@ def local_op_of_op(node):
                     assert len(newaxis) == len(list(node_inps.owner.op.axis) +
                                                list(node.op.axis))
-
                     # The old bugged logic. We keep it there to generate a warning
                     # when we generated bad code.
                     alldims = list(range(node_inps.owner.inputs[0].type.ndim))
@@ -4128,7 +4116,6 @@ def local_reduce_join(node):
     if (isinstance(node.op, T.CAReduce) and
             node.inputs[0].owner and
             isinstance(node.inputs[0].owner.op, T.Join)):
-
         join = node.inputs[0].owner
         if T.extract_constant(join.inputs[0]) != 0:
             return
@@ -4149,7 +4136,8 @@ def local_reduce_join(node):
         if not inp:
             return
         if (not isinstance(inp.op, DimShuffle) or
-            inp.op.new_order != ('x',) + tuple(range(inp.inputs[0].ndim))):
+                inp.op.new_order != ('x',) +
+                tuple(range(inp.inputs[0].ndim))):
             return
         new_inp.append(inp.inputs[0])
     ret = Elemwise(node.op.scalar_op)(*new_inp)
@@ -4174,8 +4162,7 @@ def local_reduce_join(node):
                 'optimization, that modified the pattern '
                 '"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
                 'did not check the reduction axis. So if the '
-                'reduction axis was not 0, you got a wrong answer.'
-            ))
+                'reduction axis was not 0, you got a wrong answer.'))
         return
 
     # We add the new check late to don't add extra warning.
@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node):
 # theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0
 # see gh-790 issue.
 #
-#@register_canonicalize
+# @register_canonicalize
 @register_uncanonicalize
 @register_specialize
 @gof.local_optimizer(ALL_REDUCE)
@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node):
                 if abs(y) > 2:
                     # We fuse all the pow together here to make
                     # compilation faster
-                    rval1 = Elemwise(theano.scalar.Composite(
-                        [pow2_scal[0]], [rval1_scal])).make_node(xsym)
+                    rval1 = Elemwise(
+                        theano.scalar.Composite(
+                            [pow2_scal[0]], [rval1_scal])).make_node(xsym)
                 if y < 0:
                     rval = [T.inv(rval1)]
@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators):
     # TODO: this function should dig/search through dimshuffles
     # This won't catch a dimshuffled absolute value
     for den in list(denominators):
-        if (den.owner and den.owner.op == T.abs_
-                and den.owner.inputs[0] in numerators):
+        if (den.owner and den.owner.op == T.abs_ and
+                den.owner.inputs[0] in numerators):
             if den.owner.inputs[0].type.dtype.startswith('complex'):
                 # TODO: Make an Op that projects a complex number to
                 #       have unit length but projects 0 to 0. That
@@ -4715,8 +4703,8 @@ def local_log1p(node):
     if node.op == T.log:
         log_arg, = node.inputs
         if log_arg.owner and log_arg.owner.op == T.add:
-            scalars, scalar_inputs, nonconsts = \
-                scalarconsts_rest(log_arg.owner.inputs)
+            scalars, scalar_inputs, nonconsts = scalarconsts_rest(
+                log_arg.owner.inputs)
             # scalar_inputs are potentially dimshuffled and fill'd scalars
             if scalars and numpy.allclose(numpy.sum(scalars), 1):
                 if not nonconsts:
@@ -4748,7 +4736,7 @@ def local_log_add(node):
         if len(zi) != 2:
             # -- upgrading Maximum to handle multiple inputs wasn't trivial
             #    TODO
-            #raise NotImplementedError()
+            # raise NotImplementedError()
             return
         pre_exp = [x.owner.inputs[0] for x in zi
                    if x.owner and x.owner.op == T.exp]
@@ -4946,7 +4934,6 @@ def constant_folding(node):
         compute_map[o] = [False]
     if (hasattr(node.op, 'python_constant_folding') and
            node.op.python_constant_folding(node)):
-
        old_value = getattr(node.op, '_op_use_c_code', False)
        try:
            node.op._op_use_c_code = False
@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf)
 register_stabilize(local_one_plus_neg_erf)
 register_specialize(local_one_plus_neg_erf)
 
-#(-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
+# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
 # will put the -1 as the first argument.
 local_erf_minus_one = gof.PatternSub((T.add,
                                       dict(pattern='y', constraint=_is_minus1),
@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc)
 register_stabilize(local_one_add_neg_erfc)
 register_specialize(local_one_add_neg_erfc)
 
-#(-1)+erfc(-x)=>erf(x)
+# (-1)+erfc(-x)=>erf(x)
 local_erf_neg_minus_one = gof.PatternSub((T.add,
                                           dict(pattern='y', constraint=_is_minus1),
                                           (T.erfc, (T.neg, 'x'))),
@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one)
 register_stabilize(local_erf_neg_minus_one)
 register_specialize(local_erf_neg_minus_one)
 
-#(-1)+erfc(-1*x)=>erf(x)
+# (-1)+erfc(-1*x)=>erf(x)
 local_erf_neg_minus_one2 = gof.PatternSub((T.add,
                                            dict(pattern='y', constraint=_is_minus1),
                                            (T.erfc, (T.mul, -1, 'x'))),
@@ -5176,8 +5163,8 @@ def local_log_erfc(node):
     x = node.inputs[0].owner.inputs[0]
     stab_value = (-x ** 2 - T.log(x) - .5 * T.log(numpy.pi) +
-                  T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4)
-                        - 15 / (8 * x ** 6)))
+                  T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
+                        15 / (8 * x ** 6)))
     if (node.outputs[0].dtype == 'float32' or
             node.outputs[0].dtype == 'float16'):
@@ -5191,8 +5178,8 @@ def local_log_erfc(node):
 # Stability optimization of the grad of log(erfc(x))
-#([y*]exp(-(x**2)))/erfc(x) # The y* is optional
-#([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
+# ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
+# ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
 #  sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
 # for float64: threshold=26.63 see at the end of the fct for the explaination
 # for float32: threshold=9.3 see at the end of the fct for the explaination
@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node):
     if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2:
         return False
     y = mul.owner.inputs[0]
-    if (not mul.owner.inputs[1].owner
-            or mul.owner.inputs[1].owner.op != T.exp):
+    if (not mul.owner.inputs[1].owner or
+            mul.owner.inputs[1].owner.op != T.exp):
         return False
     exp = mul.owner.inputs[1]
@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node):
     if exp.owner.inputs[0].owner.op == T.neg:
         neg = exp.owner.inputs[0]
-        if (not neg.owner.inputs[0].owner
-                or neg.owner.inputs[0].owner.op != T.sqr):
+        if (not neg.owner.inputs[0].owner or
+                neg.owner.inputs[0].owner.op != T.sqr):
             return False
         sqr = neg.owner.inputs[0]
         x = sqr.owner.inputs[0]
@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node):
             return False
         if len(mul_neg.owner.inputs) == 2:
-            if (not mul_neg.owner.inputs[1].owner
-                    or mul_neg.owner.inputs[1].owner.op != T.sqr):
+            if (not mul_neg.owner.inputs[1].owner or
+                    mul_neg.owner.inputs[1].owner.op != T.sqr):
                 return False
             sqr = mul_neg.owner.inputs[1]
             x = sqr.owner.inputs[0]
@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node):
             return False
         if cst2 != -1:
-            if (not erfc_x.owner or erfc_x.owner.op != T.mul
-                    or len(erfc_x.owner.inputs) != 2):
+            if (not erfc_x.owner or erfc_x.owner.op != T.mul or
+                    len(erfc_x.owner.inputs) != 2):
                 # todo implement that case
                 return False
             if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node):
         # aaron value
         stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) +
-                      3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1)
-                      * T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
+                      3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) *
+                      T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
 
         if x.dtype == 'float32' or x.dtype == 'float16':
             threshold = 9.3
-            #threshold = 10.1
+            # threshold = 10.1
         elif x.dtype == 'float64':
             threshold = 26.641747557
         ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
     if maker is None:
         def maker(node, scalar_op):
             return OP(scalar_op)
+
     def local_fuse(node):
         """
         As part of specialization, we fuse two consecutive elemwise Ops of the
@@ -5603,8 +5591,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
                 # Do not merge elemwise that don't have the same
                 # broadcastable pattern to don't redo duplicate
                 # computation due to broadcast.
-                i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable):
+                i.owner.outputs[0].broadcastable ==
+                node.outputs[0].broadcastable):
             do_fusion = True
             try:
                 tmp_s_input = []
@@ -5882,13 +5870,15 @@ else:
 # just returns the input, it should be removed from the graph to
 # make sure all possible optimizations can be applied.
 register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
-                      'fast_compile', 'fast_run', name='remove_consider_constant')
+                      'fast_compile', 'fast_run',
+                      name='remove_consider_constant')
 
 register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_),
                       'fast_compile', 'fast_run', name='remove_zero_grad')
 
 register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
-                      'fast_compile', 'fast_run', name='remove_disconnected_grad')
+                      'fast_compile', 'fast_run',
+                      name='remove_disconnected_grad')
 
 
 @register_canonicalize
...
@@ -63,7 +63,6 @@ whitelist_flake8 = [
     "tensor/sort.py",
     "tensor/__init__.py",
     "tensor/opt_uncanonicalize.py",
-    "tensor/opt.py",
     "tensor/blas.py",
     "tensor/extra_ops.py",
     "tensor/nlinalg.py",
...
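The second file drops "tensor/opt.py" from whitelist_flake8. The surrounding entries suggest this is the list of files exempted from the repository's flake8 check, so removing the entry opts the newly cleaned file into linting. A hypothetical sketch of how such an exemption list is typically consumed (assumed helper, not the actual Theano test code):

import subprocess


def run_flake8(python_files, whitelist_flake8):
    # Files still on the whitelist are skipped; everything else must
    # pass flake8 cleanly or check_call raises CalledProcessError.
    for path in python_files:
        if path in whitelist_flake8:
            continue
        subprocess.check_call(['flake8', path])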