提交 d40861ec authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Flake8 for theano/tensor/opt.py

上级 34b98041
......@@ -6,8 +6,6 @@ from __future__ import print_function
# TODO: 0*x -> 0
import logging
_logger = logging.getLogger('theano.tensor.opt')
import itertools
import operator
import sys
......@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
advanced_subtensor,
advanced_subtensor1,
advanced_inc_subtensor1,
inc_subtensor)
advanced_inc_subtensor1)
from theano import scalar
from theano.scalar import basic
from theano.tensor import basic as T
......@@ -56,6 +52,8 @@ from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO
_logger = logging.getLogger('theano.tensor.opt')
theano.configparser.AddConfigVar('on_shape_error',
"warn: print a warning and use the default"
" value. raise: raise an error",
......@@ -165,23 +163,24 @@ def broadcast_like(value, template, fgraph, dtype=None):
# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
if rval.broadcastable[i]
and not template.broadcastable[i]])
if rval.broadcastable[i] and
not template.broadcastable[i]])
assert rval.type.dtype == dtype
if rval.type.broadcastable != template.broadcastable:
raise AssertionError("rval.type.broadcastable is " +
str(rval.type.broadcastable) +
" but template.broadcastable is" +
str(template.broadcastable))
str(rval.type.broadcastable) +
" but template.broadcastable is" +
str(template.broadcastable))
return rval
theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb',
"-1: auto, if graph have less then 500 nodes 1, else 10",
theano.configparser.IntParam(-1),
in_c_key=False)
theano.configparser.AddConfigVar(
'tensor.insert_inplace_optimizer_validate_nb',
"-1: auto, if graph have less then 500 nodes 1, else 10",
theano.configparser.IntParam(-1),
in_c_key=False)
def inplace_elemwise_optimizer_op(OP):
......@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP):
# target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() \
and not isinstance(node.inputs[i],
Constant)\
and not fgraph.destroyers(node.inputs[i])\
and node.inputs[i] not in protected_inputs]
if i not in baseline.values() and
not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and
node.inputs[i] not in protected_inputs]
verbose = False
......@@ -265,7 +263,7 @@ def inplace_elemwise_optimizer_op(OP):
for candidate_input in candidate_inputs:
# remove inputs that don't have the same dtype as the output
if node.inputs[candidate_input].type != node.outputs[
candidate_output].type:
candidate_output].type:
continue
inplace_pattern = dict(baseline)
......@@ -274,20 +272,20 @@ def inplace_elemwise_optimizer_op(OP):
if hasattr(op.scalar_op, "make_new_inplace"):
new_scal = op.scalar_op.make_new_inplace(
scalar.transfer_type(
*[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))]))
*[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))]))
else:
new_scal = op.scalar_op.__class__(
scalar.transfer_type(
*[inplace_pattern.get(i, None) \
for i in xrange(len(node.outputs))]))
*[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))]))
new_outputs = OP(new_scal, inplace_pattern)(
*node.inputs, **dict(return_list=True))
*node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new_outputs):
fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer")
reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1
if nb_change_no_validate >= check_each_change:
fgraph.validate()
......@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP):
nb_change_no_validate = 0
except (ValueError, TypeError, InconsistencyError) as e:
if check_each_change != 1 and not raised_warning:
print((
"Some inplace optimization was not "
"performed due to unexpected error:"), file=sys.stderr)
print(("Some inplace optimization was not "
"performed due to unexpected error:"),
file=sys.stderr)
print(e, file=sys.stderr)
raised_warning = True
fgraph.revert(chk)
......@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP):
except Exception:
if not raised_warning:
print(("Some inplace optimization was not "
"performed due to unexpected error"), file=sys.stderr)
"performed due to unexpected error"),
file=sys.stderr)
fgraph.revert(chk)
return inplace_elemwise_optimizer
......@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs):
# Register merge_optimizer as a global opt during canonicalize
compile.optdb['canonicalize'].register(
'canon_merge', merge_optimizer, 'fast_run', final_opt=True)
compile.optdb['canonicalize'].register('canon_merge', merge_optimizer,
'fast_run', final_opt=True)
#####################
......@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node):
inplace. The newly-introduced transpositions are not inplace, this will
be taken care of in a later optimization phase.
"""
if not (isinstance(node.op, T.DimShuffle)
and node.op.new_order == (1, 0)):
if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)):
return False
if not (node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, T.Dot)):
if not (node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Dot)):
return False
x, y = node.inputs[0].owner.inputs
......@@ -601,22 +599,19 @@ class MakeVector(T.Op):
def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs))
if not all(a.type == inputs[0].type for a in inputs) or (
len(inputs) > 0 and inputs[0].dtype != self.dtype):
dtype = theano.scalar.upcast(self.dtype,
*[i.dtype for i in inputs])
if (not all(a.type == inputs[0].type for a in inputs) or
(len(inputs) > 0 and inputs[0].dtype != self.dtype)):
dtype = theano.scalar.upcast(self.dtype, *[i.dtype for i in inputs])
# upcast the input to the determined dtype,
# but don't downcast anything
assert dtype == self.dtype, (
"The upcast of the inputs to MakeVector should match the "
"dtype given in __init__.")
"The upcast of the inputs to MakeVector should match the "
"dtype given in __init__.")
if not all(self.dtype == T.cast(i, dtype=dtype).dtype
for i in inputs):
raise TypeError("MakeVector.make_node expected inputs"
" upcastable to %s. got %s" % (
self.dtype,
str([i.dtype for i in inputs])
))
" upcastable to %s. got %s" %
(self.dtype, str([i.dtype for i in inputs])))
inputs = [T.cast(i, dtype=dtype) for i in inputs]
assert all(self.dtype == a.dtype for a in inputs)
assert all(a.ndim == 0 for a in inputs)
......@@ -625,11 +620,9 @@ class MakeVector(T.Op):
dtype = inputs[0].type.dtype
else:
dtype = self.dtype
#bcastable = (len(inputs) == 1)
# bcastable = (len(inputs) == 1)
bcastable = False
otype = T.TensorType(
broadcastable=(bcastable,),
dtype=dtype)
otype = T.TensorType(broadcastable=(bcastable,), dtype=dtype)
return T.Apply(self, inputs, [otype()])
def __str__(self):
......@@ -700,13 +693,14 @@ class MakeVectorPrinter:
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector):
return "[%s]" % ", ".join(pstate.pprinter.process(
input, pstate.clone(precedence=1000)) for input
in r.owner.inputs)
return "[%s]" % ", ".join(
pstate.pprinter.process(input, pstate.clone(precedence=1000))
for input in r.owner.inputs)
else:
raise TypeError("Can only print make_vector.")
T.pprint.assign(lambda pstate, r: r.owner and isinstance(
r.owner.op, MakeVector), MakeVectorPrinter())
T.pprint.assign(lambda pstate, r: r.owner and
isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
class ShapeFeature(object):
......@@ -843,8 +837,8 @@ class ShapeFeature(object):
# by always returning the same object to represent 1
return self.lscalar_one
if (type(s_i) in integer_types or
isinstance(s_i, numpy.integer) or
(isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)):
isinstance(s_i, numpy.integer) or
(isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)):
# this shape is a constant
assert s_i >= 0
return T.constant(s_i, dtype='int64')
......@@ -859,9 +853,9 @@ class ShapeFeature(object):
# s_i is x.shape[i], we change it to Shape_i.
if (s_i.owner and
isinstance(s_i.owner.op, Subtensor) and
s_i.owner.inputs[0].owner and
isinstance(s_i.owner.inputs[0].owner.op, T.Shape)):
isinstance(s_i.owner.op, Subtensor) and
s_i.owner.inputs[0].owner and
isinstance(s_i.owner.inputs[0].owner.op, T.Shape)):
assert s_i.ndim == 0
assert len(s_i.owner.op.idx_list) == 1
......@@ -883,7 +877,7 @@ class ShapeFeature(object):
return s_i
else:
raise TypeError('Unsupported shape element',
s_i, type(s_i), getattr(s_i, 'type', None))
s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s):
"""Assign the shape `s` to previously un-shaped variable `r`.
......@@ -910,7 +904,7 @@ class ShapeFeature(object):
shape_vars = []
for i in xrange(r.ndim):
if (hasattr(r.type, 'broadcastable') and
r.type.broadcastable[i]):
r.type.broadcastable[i]):
shape_vars.append(self.lscalar_one)
else:
shape_vars.append(self.unpack(s[i]))
......@@ -947,8 +941,8 @@ class ShapeFeature(object):
self.set_shape(r, other_shape)
return
if (other_r.owner and r.owner and
other_r.owner.inputs == r.owner.inputs and
other_r.owner.op == r.owner.op):
other_r.owner.inputs == r.owner.inputs and
other_r.owner.op == r.owner.op):
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# ancestors() less frequently.
......@@ -957,10 +951,10 @@ class ShapeFeature(object):
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = []
for i, ps in enumerate(other_shape):
if (ps.owner
and isinstance(getattr(ps.owner, 'op', None), Shape_i)
and ps.owner.op.i == i
and ps.owner.inputs[0] in (r, other_r)):
if (ps.owner and
isinstance(getattr(ps.owner, 'op', None), Shape_i) and
ps.owner.op.i == i and
ps.owner.inputs[0] in (r, other_r)):
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
......@@ -1084,11 +1078,11 @@ class ShapeFeature(object):
r in node.inputs])
except NotImplementedError as e:
raise NotImplementedError(
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception as e:
msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
'%s\nException encountered during infer_shape: '
......@@ -1108,10 +1102,10 @@ class ShapeFeature(object):
if len(o_shapes) != len(node.outputs):
raise Exception(
('The infer_shape method for the Op "%s" returned a list ' +
'with the wrong number of element: len(o_shapes) = %d ' +
' != len(node.outputs) = %d') % (str(node.op),
len(o_shapes),
len(node.outputs)))
'with the wrong number of element: len(o_shapes) = %d ' +
' != len(node.outputs) = %d') % (str(node.op),
len(o_shapes),
len(node.outputs)))
# Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail.
......@@ -1173,9 +1167,9 @@ class ShapeFeature(object):
# with the InputToGpuOptimizer optimizer.
continue
if (repl.owner and
repl.owner.inputs[0] is shpnode.inputs[0] and
isinstance(repl.owner.op, Shape_i) and
repl.owner.op.i == shpnode.op.i):
repl.owner.inputs[0] is shpnode.inputs[0] and
isinstance(repl.owner.op, Shape_i) and
repl.owner.op.i == shpnode.op.i):
# The replacement is a shape_i of the same
# input. So no need to do this equivalent
# replacement.
......@@ -1239,7 +1233,7 @@ class ShapeFeature(object):
if not dx.owner or not dy.owner:
return False
if (not isinstance(dx.owner.op, Shape_i) or
not isinstance(dy.owner.op, Shape_i)):
not isinstance(dy.owner.op, Shape_i)):
return False
opx = dx.owner.op
opy = dy.owner.op
......@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node):
return
# TODO: cut out un-necessary dimshuffles of v
assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type,
'orig', node.outputs[0].type,
'node', node,
) # theano.printing.debugprint(node.outputs[0], file='str'))
assert rval[0].type == node.outputs[0].type, (
'rval', rval[0].type, 'orig', node.outputs[0].type, 'node',
node,) # theano.printing.debugprint(node.outputs[0], file='str'))
return rval
......@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node):
try:
idx, = node.op.idx_list
except Exception:
#'how can you have multiple indexes into a shape?'
# 'how can you have multiple indexes into a shape?'
raise
if isinstance(idx, (scalar.Scalar, T.TensorType)):
......@@ -1467,13 +1460,13 @@ def local_useless_elemwise(node):
if isinstance(node.op, T.Elemwise):
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true
# it is the same var in the graph. That will always be true
return [T.fill(node.inputs[0],
T.constant(1.0,
dtype=node.outputs[0].type.dtype))]
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false
# it is the same var in the graph. That will always be false
return [T.fill(node.inputs[0],
T.constant(0.0,
dtype=node.outputs[0].type.dtype))]
......@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
return [node.inputs[0]]
elif (node.op.scalar_op == theano.scalar.identity
and len(node.inputs) == 1):
elif (node.op.scalar_op == theano.scalar.identity and
len(node.inputs) == 1):
return [node.inputs[0]]
......@@ -1513,12 +1506,12 @@ def local_cast_cast(node):
and the first cast cause an upcast.
"""
if (not isinstance(node.op, T.Elemwise) or
not isinstance(node.op.scalar_op, scalar.Cast)):
not isinstance(node.op.scalar_op, scalar.Cast)):
return
x = node.inputs[0]
if (not x.owner or
not isinstance(x.owner.op, T.Elemwise) or
not isinstance(x.owner.op.scalar_op, scalar.Cast)):
not isinstance(x.owner.op, T.Elemwise) or
not isinstance(x.owner.op.scalar_op, scalar.Cast)):
return
if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type:
return [x]
......@@ -1738,7 +1731,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# The broadcast pattern of the ouptut must match the broadcast
# pattern of at least one of the inputs.
if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]):
node.outputs[0].type.broadcastable for i in node.inputs]):
return False
def dimshuffled_alloc(i):
......@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# At least one input must have an owner that is either a AllocOP or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize.
if not any([i.owner
and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i))
for i in node.inputs]):
if not any([i.owner and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i)) for i in node.inputs]):
return False
# Search for input that we can use as a baseline for the dimensions.
......@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# AllocOP so that all allocs can be optimized.
if not (i.owner
and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i))):
if not (i.owner and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i))):
assert_op_idx = idx
break
......@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# there is more than one then do all but one. number of
# inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i)))]
if (i.owner and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed.
if len(l2) > 1:
......@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
same_shape = node.fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP)
and i.owner.inputs[0].type != i.owner.outputs[0].type):
if (i.owner and isinstance(i.owner.op, AllocOP) and
i.owner.inputs[0].type != i.owner.outputs[0].type):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later
assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert
and not same_shape(i, cmp_op)):
if (theano.config.experimental.local_alloc_elemwise_assert and
not same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
......@@ -1891,7 +1880,7 @@ def local_upcast_elemwise_constant_inputs(node):
scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference
if (getattr(scalar_op, 'output_types_preference', None)
in (T.scal.upgrade_to_float, T.scal.upcast_out)):
in (T.scal.upgrade_to_float, T.scal.upcast_out)):
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype
......@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node):
i.ndim))
else:
if shape_i is None:
return
new_inputs.append(T.alloc(T.cast(cval_i,
output_dtype),
*[shape_i(d)(i) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)]
return new_inputs.append(
T.alloc(T.cast(cval_i, output_dtype),
*[shape_i(d)(i)
for d in xrange(i.ndim)]))
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in xrange(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, T.TensorConstant):
......@@ -1958,7 +1947,7 @@ def local_useless_inc_subtensor(node):
except NotScalarConstantError:
return
if (node.inputs[0].ndim != node.inputs[1].ndim or
node.inputs[0].broadcastable != node.inputs[1].broadcastable):
node.inputs[0].broadcastable != node.inputs[1].broadcastable):
# FB: I didn't check if this case can happen, but this opt
# don't support it.
return
......@@ -1994,16 +1983,16 @@ def local_set_to_inc_subtensor(node):
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
"""
if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc == True and
node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
node.op.set_instead_of_inc and
node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
addn = node.inputs[1].owner
subn = None
other = None
if (addn.inputs[0].owner and
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[0].owner
other = addn.inputs[1]
elif (addn.inputs[1].owner and
......@@ -2013,7 +2002,7 @@ def local_set_to_inc_subtensor(node):
else:
return
if (subn.inputs[1] != node.inputs[2] or
subn.inputs[0] != node.inputs[0]):
subn.inputs[0] != node.inputs[0]):
return
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
......@@ -2030,9 +2019,9 @@ def local_useless_slice(node):
last_slice = len(slices)
for s in slices[::-1]:
# check if slice and then check slice indices
if (isinstance(s, slice) and s.start is None and s.stop is None
and (s.step is None or T.extract_constant(s.step) == 1)):
last_slice -= 1
if (isinstance(s, slice) and s.start is None and s.stop is None and
(s.step is None or T.extract_constant(s.step) == 1)):
last_slice -= 1
else:
break
# check if we removed something
......@@ -2098,11 +2087,10 @@ def local_useless_subtensor(node):
# the same underlying variable.
if (length_pos_shape_i.owner and
isinstance(length_pos_shape_i.owner.op,
T.ScalarFromTensor)):
T.ScalarFromTensor)):
length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
elif (length_pos.owner and
isinstance(length_pos.owner.op,
T.TensorFromScalar)):
isinstance(length_pos.owner.op, T.TensorFromScalar)):
length_pos = length_pos.owner.inputs[0]
else:
# We did not find underlying variables of the same type
......@@ -2322,8 +2310,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pn_stop = sl1.start + (sl2.start - 1) * sl1.step
pn_stop = T.switch(T.and_(T.lt(pn_stop, 0),
T.gt(flen, 0)),
-len1 - 1,
T.minimum(pn_stop, sl1.stop))
-len1 - 1,
T.minimum(pn_stop, sl1.stop))
pn_start = sl1.start + (sl2.stop - 1) * sl1.step
pn_start = T.minimum(pn_start, sl1.stop)
pn_start = T.maximum(pn_start, 0)
......@@ -2345,9 +2333,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pp_start))
stop = T.switch(T.lt(reverse2 * reverse1, 0),
T.switch(T.lt(reverse1, 0), np_stop, pn_stop),
T.switch(T.lt(reverse1, 0), nn_stop, pp_stop
))
T.switch(T.lt(reverse1, 0), np_stop, pn_stop),
T.switch(T.lt(reverse1, 0), nn_stop, pp_stop))
step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step)
start = T.switch(T.le(flen, 0), 0, start)
......@@ -2463,7 +2450,7 @@ def local_subtensor_of_alloc(node):
# We check that the corresponding val dimensions was
# not a broadcasted dimensions.
if (val.type.ndim > (i - n_added_dims) and
val.type.broadcastable[i - n_added_dims]):
val.type.broadcastable[i - n_added_dims]):
val_slices.append(slice(None))
else:
val_slices.append(sl)
......@@ -2496,8 +2483,8 @@ def local_subtensor_of_alloc(node):
rval[0] = theano.tensor.unbroadcast(
rval[0],
*[i for i, (b1, b2) in enumerate(zip(rval[0].broadcastable,
node.outputs[0].broadcastable))
if b1 and not b2])
node.outputs[0].broadcastable))
if b1 and not b2])
return rval
......@@ -2518,7 +2505,7 @@ def local_subtensor_of_dot(node):
if not isinstance(node.op, Subtensor):
return
if (not node.inputs[0].owner or
not isinstance(node.inputs[0].owner.op, T.Dot)):
not isinstance(node.inputs[0].owner.op, T.Dot)):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
......@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node):
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
b_indices = b_indices[:b.ndim-2] + (slice(None, None, None),) + b_indices[b.ndim-2:]
b_indices = (b_indices[:b.ndim - 2] +
(slice(None, None, None),) + b_indices[b.ndim - 2:])
a_sub = a.__getitem__(tuple(a_indices))
b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
......@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node):
"""
def movable(i):
# Return True iff this is a incsubtensor that we can move
return i.owner \
and isinstance(i.owner.op, (IncSubtensor,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,
)) \
and i.type == o_type \
and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc
return (i.owner and
isinstance(i.owner.op, (IncSubtensor,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,)) and
i.type == o_type and
len(i.clients) == 1 and
not i.owner.op.set_instead_of_inc)
if node.op == T.add:
o_type = node.outputs[0].type
......@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node):
movable_inputs = [i for i in node.inputs if movable(i)]
if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] \
+ [mi.owner.inputs[0] for mi in movable_inputs]
new_inputs = ([i for i in node.inputs if not movable(i)] +
[mi.owner.inputs[0] for mi in movable_inputs])
new_add = T.add(*new_inputs)
# stack up the new incsubtensors
......@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node):
return [new_node]
return False
compile.optdb.register('local_inplace_setsubtensor',
TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG
TopoOptimizer(
local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
......@@ -2653,8 +2641,8 @@ def local_inplace_incsubtensor1(node):
return False
compile.optdb.register('local_inplace_incsubtensor1',
TopoOptimizer(
local_inplace_incsubtensor1,
failure_callback=TopoOptimizer.warn_inplace),
local_inplace_incsubtensor1,
failure_callback=TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace') # DEBUG
......@@ -2671,7 +2659,7 @@ def local_incsubtensor_of_zeros(node):
if (isinstance(node.op, (IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1)) and
not node.op.set_instead_of_inc):
not node.op.set_instead_of_inc):
x = node.inputs[0]
y = node.inputs[1]
replace = False
......@@ -2713,8 +2701,8 @@ def local_setsubtensor_of_constants(node):
pass
if (replace_x is not None and
replace_y is not None and
replace_x == replace_y):
replace_y is not None and
replace_x == replace_y):
return [x]
else:
return False
......@@ -2738,7 +2726,7 @@ def local_adv_sub1_adv_inc_sub1(node):
return
inp = node.inputs[0]
if (not inp.owner or
not isinstance(inp.owner.op, AdvancedIncSubtensor1)):
not isinstance(inp.owner.op, AdvancedIncSubtensor1)):
return
idx = node.inputs[1]
idx2 = inp.owner.inputs[2]
......@@ -2747,13 +2735,13 @@ def local_adv_sub1_adv_inc_sub1(node):
if idx is not idx2:
return
if (not inp.owner.op.set_instead_of_inc and
T.extract_constant(x) != 0):
T.extract_constant(x) != 0):
return
cond = [T.all(T.and_(T.lt(idx, x.shape[0]),
T.ge(idx, -x.shape[0])))]
cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
cond.append(T.eq(idx.shape[0], y.shape[0]))
y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 that was optimized away")(y, *cond)
y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 "
"that was optimized away")(y, *cond)
if y.dtype == node.outputs[0].dtype:
return [y]
......@@ -2828,33 +2816,34 @@ def local_useless_inc_subtensor_alloc(node):
# Build `z_broad` explicitly to include extra implicit dimensions.
z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable)
cond = [# The shapes of `y` and `xi` must either agree or `y` may
# also have shape equal to 1 which may be treated as a
# broadcastable dimension by the subtensor op.
T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k]))
# Loop over all dimensions.
for k in xrange(xi.ndim)
# We need to check the above shapes, if
# * the pre-alloc increment `z` is broadcastable in
# dimension `k` (if it isn't, then the shapes of `z` and
# `y` are the same by the definition of the `Alloc` op in
# this dimension and replacing `y` by `z` will not hide a
# shape error), and
# * `xi` and `y` do not have the same shape in dimension
# `k` or we cannot infer the shape statically (if the
# shapes of `xi` and `y` are not the same, then replacing
# `y` by `z` will hide the shape error of `y`), and
# * the shape of `y` is not equal to 1 or we cannot infer
# the shape statically (if the shape of `y` is equal to
# 1, then `y` is broadcasted by the inc_subtensor op
# internally, so the shapes of `xi` and `y` do not need
# to match in dimension `k`; else we need to check at
# runtime that the shape of `y` is either 1 or the same
# as `xi` or otherwise replacing `y` by `z` will hide a
# shape error).
if (z_broad[k] and
not same_shape(xi, y, dim_x=k, dim_y=k) and
shape_of[y][k] != 1)]
cond = [
# The shapes of `y` and `xi` must either agree or `y` may
# also have shape equal to 1 which may be treated as a
# broadcastable dimension by the subtensor op.
T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k]))
# Loop over all dimensions.
for k in xrange(xi.ndim)
# We need to check the above shapes, if
# * the pre-alloc increment `z` is broadcastable in
# dimension `k` (if it isn't, then the shapes of `z` and
# `y` are the same by the definition of the `Alloc` op in
# this dimension and replacing `y` by `z` will not hide a
# shape error), and
# * `xi` and `y` do not have the same shape in dimension
# `k` or we cannot infer the shape statically (if the
# shapes of `xi` and `y` are not the same, then replacing
# `y` by `z` will hide the shape error of `y`), and
# * the shape of `y` is not equal to 1 or we cannot infer
# the shape statically (if the shape of `y` is equal to
# 1, then `y` is broadcasted by the inc_subtensor op
# internally, so the shapes of `xi` and `y` do not need
# to match in dimension `k`; else we need to check at
# runtime that the shape of `y` is either 1 or the same
# as `xi` or otherwise replacing `y` by `z` will hide a
# shape error).
if (z_broad[k] and
not same_shape(xi, y, dim_x=k, dim_y=k) and
shape_of[y][k] != 1)]
if len(cond) > 0:
msg = '`x[i]` and `y` do not have the same shape.'
......@@ -2916,7 +2905,7 @@ def local_rebroadcast_lift(node):
# compilation phase.
if hasattr(input, 'clients') and len(input.clients) == 1:
rval = inode.op.make_node(T.Rebroadcast(*list(op.axis.items()))(
inode.inputs[0])).outputs
inode.inputs[0])).outputs
return rval
if inode and isinstance(inode.op, T.Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides
......@@ -3031,11 +3020,11 @@ def local_join_make_vector(node):
for idx in xrange(2, len(node.inputs)):
inp = node.inputs[idx]
if (inp.owner and
isinstance(inp.owner.op, MakeVector) and
new_inputs[-1].owner and
isinstance(new_inputs[-1].owner.op, MakeVector) and
# MakeVector have a dtype parameter
inp.owner.op == new_inputs[-1].owner.op):
isinstance(inp.owner.op, MakeVector) and
new_inputs[-1].owner and
isinstance(new_inputs[-1].owner.op, MakeVector) and
# MakeVector have a dtype parameter
inp.owner.op == new_inputs[-1].owner.op):
inps = new_inputs[-1].owner.inputs + inp.owner.inputs
new_inputs[-1] = inp.owner.op(*inps)
else:
......@@ -3059,7 +3048,7 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond != 0: left
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)):
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
......@@ -3241,9 +3230,9 @@ def local_flatten_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
"""
if (isinstance(node.op, T.Flatten) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
f = node.op(node.inputs[0].owner.inputs[0])
e = node.inputs[0].owner.op(f)
return [e]
......@@ -3290,9 +3279,9 @@ def local_reshape_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
"""
if (isinstance(node.op, T.Reshape) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
e = node.inputs[0].owner.op(r)
# In rare case the original broadcast was (False, True), but
......@@ -3539,7 +3528,7 @@ class Canonizer(gof.LocalOptimizer):
return [input], []
if input.owner is None or input.owner.op not in [
self.main, self.inverse, self.reciprocal]:
self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
# If input is a DimShuffle of some input which does
# something like this:
......@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer):
# the num/denum of its input
dsn = input.owner # dimshuffle node
dsop = dsn.op # dimshuffle op
dsi0 = dsn.inputs[0] # the first input of the
# dimshuffle i.e. the ndarray to
# redim
# the first input of the dimshuffle i.e. the ndarray to redim
dsi0 = dsn.inputs[0]
# The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
......@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer):
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# later on).
compatible_order = ('x',) * (input.type.ndim
- dsi0.type.ndim) + tuple(
range(dsi0.type.ndim))
compatible_order = (('x',) *
(input.type.ndim - dsi0.type.ndim) +
tuple(range(dsi0.type.ndim)))
if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
......@@ -3815,9 +3804,9 @@ class Canonizer(gof.LocalOptimizer):
new = self.merge_num_denum(num, denum)
if new.type.dtype != out.type.dtype:
#new = T.fill(out, new)
# new = T.fill(out, new)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(
getattr(scalar, out.type.dtype))))
getattr(scalar, out.type.dtype))))
new = elem_op(new)
assert (new.type == out.type) == (not (new.type != out.type))
......@@ -3833,12 +3822,12 @@ class Canonizer(gof.LocalOptimizer):
else:
_logger.warning(' '.join(('CANONIZE FAILED: new, out = ',
new, ',', out, 'types',
new.type, ',', out.type)))
new.type, ',', out.type)))
return False
def __str__(self):
return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (
self.main, self.inverse, self.reciprocal))
self.main, self.inverse, self.reciprocal))
def mul_calculate(num, denum, aslist=False, out_type=None):
......@@ -3872,7 +3861,7 @@ register_canonicalize(local_mul_canonizer, name='local_mul_canonizer')
def local_neg_to_mul(node):
if node.op == T.neg:
return [T.mul(numpy.array(-1, dtype=node.inputs[0].dtype),
node.inputs[0])]
node.inputs[0])]
register_canonicalize(local_neg_to_mul)
......@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node):
"""
Elemwise{sub}(X,X) -> zeros_like(X)
"""
if (isinstance(node.op, T.Elemwise)
and node.op.scalar_op.nin == 2
and node.op.scalar_op == scalar.sub
and node.inputs[0] == node.inputs[1]):
if (isinstance(node.op, T.Elemwise) and
node.op.scalar_op.nin == 2 and
node.op.scalar_op == scalar.sub and
node.inputs[0] == node.inputs[1]):
return [T.zeros_like(node.inputs[0])]
......@@ -4013,9 +4002,8 @@ def local_sum_div_dimshuffle(node):
' to False.')
new_denom = T.DimShuffle(
thing_dimshuffled.type.broadcastable,
new_new_order
)(thing_dimshuffled)
thing_dimshuffled.type.broadcastable,
new_new_order)(thing_dimshuffled)
return [T.true_div(node.op(numerator), new_denom)]
# else:
# print 'incompatible dims:', axis, new_order
......@@ -4052,8 +4040,9 @@ def local_op_of_op(node):
# We manipulate the graph so this is done to make sure the opt
# doesn't affect other computations.
if len(node_inps.clients) == 1:
if (node_inps.owner and (isinstance(node_inps.owner.op, T.elemwise.Prod)
or isinstance(node_inps.owner.op, T.elemwise.Sum))):
if (node_inps.owner and
(isinstance(node_inps.owner.op, T.elemwise.Prod) or
isinstance(node_inps.owner.op, T.elemwise.Sum))):
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it
......@@ -4074,7 +4063,6 @@ def local_op_of_op(node):
assert len(newaxis) == len(list(node_inps.owner.op.axis) +
list(node.op.axis))
# The old bugged logic. We keep it there to generate a warning
# when we generated bad code.
alldims = list(range(node_inps.owner.inputs[0].type.ndim))
......@@ -4087,20 +4075,20 @@ def local_op_of_op(node):
if i not in alldims]
if (theano.config.warn.sum_sum_bug and
newaxis != newaxis_old and
len(newaxis) == len(newaxis_old)):
newaxis != newaxis_old and
len(newaxis) == len(newaxis_old)):
_logger.warn(
"WARNING (YOUR CURRENT CODE IS FINE): Theano "
"versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. "
"This happens when there are two consecutive sums "
"in the graph and the intermediate sum is not "
"used elsewhere in the code. Some safeguard "
"removed some bad code, but not in all cases. You "
"are in one such case. To disable this warning "
"(that you can safely ignore since this bug has "
"been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.")
"WARNING (YOUR CURRENT CODE IS FINE): Theano "
"versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. "
"This happens when there are two consecutive sums "
"in the graph and the intermediate sum is not "
"used elsewhere in the code. Some safeguard "
"removed some bad code, but not in all cases. You "
"are in one such case. To disable this warning "
"(that you can safely ignore since this bug has "
"been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.")
combined = opt_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])]
......@@ -4126,9 +4114,8 @@ def local_reduce_join(node):
"""
if (isinstance(node.op, T.CAReduce) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Join)):
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Join)):
join = node.inputs[0].owner
if T.extract_constant(join.inputs[0]) != 0:
return
......@@ -4149,7 +4136,8 @@ def local_reduce_join(node):
if not inp:
return
if (not isinstance(inp.op, DimShuffle) or
inp.op.new_order != ('x',) + tuple(range(inp.inputs[0].ndim))):
inp.op.new_order != ('x',) +
tuple(range(inp.inputs[0].ndim))):
return
new_inp.append(inp.inputs[0])
ret = Elemwise(node.op.scalar_op)(*new_inp)
......@@ -4174,9 +4162,8 @@ def local_reduce_join(node):
'optimization, that modified the pattern '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'did not check the reduction axis. So if the '
'reduction axis was not 0, you got a wrong answer.'
))
return
'reduction axis was not 0, you got a wrong answer.'))
return
# We add the new check late to don't add extra warning.
try:
......@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node):
# theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0
# see gh-790 issue.
#
#@register_canonicalize
# @register_canonicalize
@register_uncanonicalize
@register_specialize
@gof.local_optimizer(ALL_REDUCE)
......@@ -4258,7 +4245,7 @@ def local_opt_alloc(node):
input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
node.op.axis == tuple(range(input.ndim))):
try:
val = get_scalar_constant_value(input)
assert val.size == 1
......@@ -4346,7 +4333,7 @@ register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div])
def local_div_to_inv(node):
if node.op == T.true_div and N.all(
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0]
new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:],
[]))
......@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node):
if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(theano.scalar.Composite(
rval1 = Elemwise(
theano.scalar.Composite(
[pow2_scal[0]], [rval1_scal])).make_node(xsym)
if y < 0:
rval = [T.inv(rval1)]
......@@ -4566,8 +4554,8 @@ def local_mul_specialize(node):
else:
# The next case would cause a replace by an equivalent case.
if (neg and
nb_neg_node == 0 and
nb_cst == 1):
nb_neg_node == 0 and
nb_cst == 1):
return
elif neg:
# Don't add an extra neg node as we can't
......@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators):
# TODO: this function should dig/search through dimshuffles
# This won't catch a dimshuffled absolute value
for den in list(denominators):
if (den.owner and den.owner.op == T.abs_
and den.owner.inputs[0] in numerators):
if (den.owner and den.owner.op == T.abs_ and
den.owner.inputs[0] in numerators):
if den.owner.inputs[0].type.dtype.startswith('complex'):
# TODO: Make an Op that projects a complex number to
# have unit length but projects 0 to 0. That
......@@ -4715,8 +4703,8 @@ def local_log1p(node):
if node.op == T.log:
log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add:
scalars, scalar_inputs, nonconsts = \
scalarconsts_rest(log_arg.owner.inputs)
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
log_arg.owner.inputs)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts:
......@@ -4748,7 +4736,7 @@ def local_log_add(node):
if len(zi) != 2:
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO
#raise NotImplementedError()
# raise NotImplementedError()
return
pre_exp = [x.owner.inputs[0] for x in zi
if x.owner and x.owner.op == T.exp]
......@@ -4945,8 +4933,7 @@ def constant_folding(node):
storage_map[o] = [None]
compute_map[o] = [False]
if (hasattr(node.op, 'python_constant_folding') and
node.op.python_constant_folding(node)):
node.op.python_constant_folding(node)):
old_value = getattr(node.op, '_op_use_c_code', False)
try:
node.op._op_use_c_code = False
......@@ -5037,9 +5024,9 @@ register_specialize(local_one_minus_erf)
local_one_minus_erf2 = gof.PatternSub((T.add,
1,
(T.mul, -1, (T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erf2')
(T.erfc, 'x'),
allow_multiple_clients=True,
name='local_one_minus_erf2')
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)
......@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf)
#(-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument.
local_erf_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
......@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc)
#(-1)+erfc(-x)=>erf(x)
# (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
(T.erfc, (T.neg, 'x'))),
......@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one)
#(-1)+erfc(-1*x)=>erf(x)
# (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1),
(T.erfc, (T.mul, -1, 'x'))),
......@@ -5176,8 +5163,8 @@ def local_log_erfc(node):
x = node.inputs[0].owner.inputs[0]
stab_value = (-x ** 2 - T.log(x) - .5 * T.log(numpy.pi) +
T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4)
- 15 / (8 * x ** 6)))
T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
15 / (8 * x ** 6)))
if (node.outputs[0].dtype == 'float32' or
node.outputs[0].dtype == 'float16'):
......@@ -5191,8 +5178,8 @@ def local_log_erfc(node):
# Stability optimization of the grad of log(erfc(x))
#([y*]exp(-(x**2)))/erfc(x) # The y* is optional
#([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
# ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
# ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination
......@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node):
if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2:
return False
y = mul.owner.inputs[0]
if (not mul.owner.inputs[1].owner
or mul.owner.inputs[1].owner.op != T.exp):
if (not mul.owner.inputs[1].owner or
mul.owner.inputs[1].owner.op != T.exp):
return False
exp = mul.owner.inputs[1]
......@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node):
if exp.owner.inputs[0].owner.op == T.neg:
neg = exp.owner.inputs[0]
if (not neg.owner.inputs[0].owner
or neg.owner.inputs[0].owner.op != T.sqr):
if (not neg.owner.inputs[0].owner or
neg.owner.inputs[0].owner.op != T.sqr):
return False
sqr = neg.owner.inputs[0]
x = sqr.owner.inputs[0]
......@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node):
return False
if len(mul_neg.owner.inputs) == 2:
if (not mul_neg.owner.inputs[1].owner
or mul_neg.owner.inputs[1].owner.op != T.sqr):
if (not mul_neg.owner.inputs[1].owner or
mul_neg.owner.inputs[1].owner.op != T.sqr):
return False
sqr = mul_neg.owner.inputs[1]
x = sqr.owner.inputs[0]
......@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node):
return False
if cst2 != -1:
if (not erfc_x.owner or erfc_x.owner.op != T.mul
or len(erfc_x.owner.inputs) != 2):
if (not erfc_x.owner or erfc_x.owner.op != T.mul or
len(erfc_x.owner.inputs) != 2):
# todo implement that case
return False
if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
......@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node):
# aaron value
stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) +
3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1)
* T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) *
T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
if x.dtype == 'float32' or x.dtype == 'float16':
threshold = 9.3
#threshold = 10.1
# threshold = 10.1
elif x.dtype == 'float64':
threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
......@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
if maker is None:
def maker(node, scalar_op):
return OP(scalar_op)
def local_fuse(node):
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
......@@ -5598,13 +5586,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
# If a variable is used as multiple into to the same node,
# we still want to fusion. So we take the set.
if (i.owner and
isinstance(i.owner.op, OP) and
len(set([n for n, idx in i.clients])) == 1 and
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable):
isinstance(i.owner.op, OP) and
len(set([n for n, idx in i.clients])) == 1 and
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
i.owner.outputs[0].broadcastable ==
node.outputs[0].broadcastable):
do_fusion = True
try:
tmp_s_input = []
......@@ -5840,14 +5828,14 @@ def local_add_mul_fusion(node):
"""
if (not isinstance(node.op, Elemwise) or
not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))):
not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))):
return False
s_op = node.op.scalar_op.__class__
for inp in node.inputs:
if (inp.owner and
isinstance(inp.owner.op, Elemwise) and
isinstance(inp.owner.op.scalar_op, s_op)):
isinstance(inp.owner.op, Elemwise) and
isinstance(inp.owner.op.scalar_op, s_op)):
l = list(node.inputs)
l.remove(inp)
return [node.op(*(l + inp.owner.inputs))]
......@@ -5882,13 +5870,15 @@ else:
# just returns the input, it should be removed from the graph to
# make sure all possible optimizations can be applied.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
'fast_compile', 'fast_run', name='remove_consider_constant')
'fast_compile', 'fast_run',
name='remove_consider_constant')
register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_),
'fast_compile', 'fast_run', name='remove_zero_grad')
'fast_compile', 'fast_run', name='remove_zero_grad')
register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
'fast_compile', 'fast_run', name='remove_disconnected_grad')
'fast_compile', 'fast_run',
name='remove_disconnected_grad')
@register_canonicalize
......
......@@ -63,7 +63,6 @@ whitelist_flake8 = [
"tensor/sort.py",
"tensor/__init__.py",
"tensor/opt_uncanonicalize.py",
"tensor/opt.py",
"tensor/blas.py",
"tensor/extra_ops.py",
"tensor/nlinalg.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论