Commit d40861ec, authored by Iban Harlouchet

Flake8 for theano/tensor/opt.py

Parent commit: 34b98041
...@@ -6,8 +6,6 @@ from __future__ import print_function ...@@ -6,8 +6,6 @@ from __future__ import print_function
# TODO: 0*x -> 0 # TODO: 0*x -> 0
import logging import logging
_logger = logging.getLogger('theano.tensor.opt')
import itertools import itertools
import operator import operator
import sys import sys
...@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -34,12 +32,10 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
advanced_subtensor, advanced_subtensor,
advanced_subtensor1, advanced_subtensor1,
advanced_inc_subtensor1, advanced_inc_subtensor1)
inc_subtensor)
from theano import scalar from theano import scalar
from theano.scalar import basic from theano.scalar import basic
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -56,6 +52,8 @@ from theano.gof import toolbox ...@@ -56,6 +52,8 @@ from theano.gof import toolbox
from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError from theano.tensor.basic import get_scalar_constant_value, ShapeError, NotScalarConstantError
from six import StringIO from six import StringIO
_logger = logging.getLogger('theano.tensor.opt')
theano.configparser.AddConfigVar('on_shape_error', theano.configparser.AddConfigVar('on_shape_error',
"warn: print a warning and use the default" "warn: print a warning and use the default"
" value. raise: raise an error", " value. raise: raise an error",
...@@ -165,23 +163,24 @@ def broadcast_like(value, template, fgraph, dtype=None): ...@@ -165,23 +163,24 @@ def broadcast_like(value, template, fgraph, dtype=None):
# the template may have 1s in its shape without being broadcastable # the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable: if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
if rval.broadcastable[i] if rval.broadcastable[i] and
and not template.broadcastable[i]]) not template.broadcastable[i]])
assert rval.type.dtype == dtype assert rval.type.dtype == dtype
if rval.type.broadcastable != template.broadcastable: if rval.type.broadcastable != template.broadcastable:
raise AssertionError("rval.type.broadcastable is " + raise AssertionError("rval.type.broadcastable is " +
str(rval.type.broadcastable) + str(rval.type.broadcastable) +
" but template.broadcastable is" + " but template.broadcastable is" +
str(template.broadcastable)) str(template.broadcastable))
return rval return rval
theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb', theano.configparser.AddConfigVar(
"-1: auto, if graph have less then 500 nodes 1, else 10", 'tensor.insert_inplace_optimizer_validate_nb',
theano.configparser.IntParam(-1), "-1: auto, if graph have less then 500 nodes 1, else 10",
in_c_key=False) theano.configparser.IntParam(-1),
in_c_key=False)
def inplace_elemwise_optimizer_op(OP): def inplace_elemwise_optimizer_op(OP):
...@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -251,11 +250,10 @@ def inplace_elemwise_optimizer_op(OP):
# target. # target.
# Remove here as faster. # Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs)) candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() \ if i not in baseline.values() and
and not isinstance(node.inputs[i], not isinstance(node.inputs[i], Constant) and
Constant)\ not fgraph.destroyers(node.inputs[i]) and
and not fgraph.destroyers(node.inputs[i])\ node.inputs[i] not in protected_inputs]
and node.inputs[i] not in protected_inputs]
verbose = False verbose = False
...@@ -265,7 +263,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -265,7 +263,7 @@ def inplace_elemwise_optimizer_op(OP):
for candidate_input in candidate_inputs: for candidate_input in candidate_inputs:
# remove inputs that don't have the same dtype as the output # remove inputs that don't have the same dtype as the output
if node.inputs[candidate_input].type != node.outputs[ if node.inputs[candidate_input].type != node.outputs[
candidate_output].type: candidate_output].type:
continue continue
inplace_pattern = dict(baseline) inplace_pattern = dict(baseline)
...@@ -274,20 +272,20 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -274,20 +272,20 @@ def inplace_elemwise_optimizer_op(OP):
if hasattr(op.scalar_op, "make_new_inplace"): if hasattr(op.scalar_op, "make_new_inplace"):
new_scal = op.scalar_op.make_new_inplace( new_scal = op.scalar_op.make_new_inplace(
scalar.transfer_type( scalar.transfer_type(
*[inplace_pattern.get(i, None) \ *[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))])) for i in xrange(len(node.outputs))]))
else: else:
new_scal = op.scalar_op.__class__( new_scal = op.scalar_op.__class__(
scalar.transfer_type( scalar.transfer_type(
*[inplace_pattern.get(i, None) \ *[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))])) for i in xrange(len(node.outputs))]))
new_outputs = OP(new_scal, inplace_pattern)( new_outputs = OP(new_scal, inplace_pattern)(
*node.inputs, **dict(return_list=True)) *node.inputs, **dict(return_list=True))
new_node = new_outputs[0].owner new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new_outputs): for r, new_r in zip(node.outputs, new_outputs):
fgraph.replace(r, new_r, fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer") reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1 nb_change_no_validate += 1
if nb_change_no_validate >= check_each_change: if nb_change_no_validate >= check_each_change:
fgraph.validate() fgraph.validate()
...@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -295,9 +293,9 @@ def inplace_elemwise_optimizer_op(OP):
nb_change_no_validate = 0 nb_change_no_validate = 0
except (ValueError, TypeError, InconsistencyError) as e: except (ValueError, TypeError, InconsistencyError) as e:
if check_each_change != 1 and not raised_warning: if check_each_change != 1 and not raised_warning:
print(( print(("Some inplace optimization was not "
"Some inplace optimization was not " "performed due to unexpected error:"),
"performed due to unexpected error:"), file=sys.stderr) file=sys.stderr)
print(e, file=sys.stderr) print(e, file=sys.stderr)
raised_warning = True raised_warning = True
fgraph.revert(chk) fgraph.revert(chk)
...@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -313,7 +311,8 @@ def inplace_elemwise_optimizer_op(OP):
except Exception: except Exception:
if not raised_warning: if not raised_warning:
print(("Some inplace optimization was not " print(("Some inplace optimization was not "
"performed due to unexpected error"), file=sys.stderr) "performed due to unexpected error"),
file=sys.stderr)
fgraph.revert(chk) fgraph.revert(chk)
return inplace_elemwise_optimizer return inplace_elemwise_optimizer
...@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -381,8 +380,8 @@ def register_specialize_device(lopt, *tags, **kwargs):
# Register merge_optimizer as a global opt during canonicalize # Register merge_optimizer as a global opt during canonicalize
compile.optdb['canonicalize'].register( compile.optdb['canonicalize'].register('canon_merge', merge_optimizer,
'canon_merge', merge_optimizer, 'fast_run', final_opt=True) 'fast_run', final_opt=True)
##################### #####################
...@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node): ...@@ -512,11 +511,10 @@ def local_lift_transpose_through_dot(node):
inplace. The newly-introduced transpositions are not inplace, this will inplace. The newly-introduced transpositions are not inplace, this will
be taken care of in a later optimization phase. be taken care of in a later optimization phase.
""" """
if not (isinstance(node.op, T.DimShuffle) if not (isinstance(node.op, T.DimShuffle) and node.op.new_order == (1, 0)):
and node.op.new_order == (1, 0)):
return False return False
if not (node.inputs[0].owner if not (node.inputs[0].owner and
and isinstance(node.inputs[0].owner.op, T.Dot)): isinstance(node.inputs[0].owner.op, T.Dot)):
return False return False
x, y = node.inputs[0].owner.inputs x, y = node.inputs[0].owner.inputs
...@@ -601,22 +599,19 @@ class MakeVector(T.Op): ...@@ -601,22 +599,19 @@ class MakeVector(T.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs)) inputs = list(map(T.as_tensor_variable, inputs))
if not all(a.type == inputs[0].type for a in inputs) or ( if (not all(a.type == inputs[0].type for a in inputs) or
len(inputs) > 0 and inputs[0].dtype != self.dtype): (len(inputs) > 0 and inputs[0].dtype != self.dtype)):
dtype = theano.scalar.upcast(self.dtype, dtype = theano.scalar.upcast(self.dtype, *[i.dtype for i in inputs])
*[i.dtype for i in inputs])
# upcast the input to the determined dtype, # upcast the input to the determined dtype,
# but don't downcast anything # but don't downcast anything
assert dtype == self.dtype, ( assert dtype == self.dtype, (
"The upcast of the inputs to MakeVector should match the " "The upcast of the inputs to MakeVector should match the "
"dtype given in __init__.") "dtype given in __init__.")
if not all(self.dtype == T.cast(i, dtype=dtype).dtype if not all(self.dtype == T.cast(i, dtype=dtype).dtype
for i in inputs): for i in inputs):
raise TypeError("MakeVector.make_node expected inputs" raise TypeError("MakeVector.make_node expected inputs"
" upcastable to %s. got %s" % ( " upcastable to %s. got %s" %
self.dtype, (self.dtype, str([i.dtype for i in inputs])))
str([i.dtype for i in inputs])
))
inputs = [T.cast(i, dtype=dtype) for i in inputs] inputs = [T.cast(i, dtype=dtype) for i in inputs]
assert all(self.dtype == a.dtype for a in inputs) assert all(self.dtype == a.dtype for a in inputs)
assert all(a.ndim == 0 for a in inputs) assert all(a.ndim == 0 for a in inputs)
...@@ -625,11 +620,9 @@ class MakeVector(T.Op): ...@@ -625,11 +620,9 @@ class MakeVector(T.Op):
dtype = inputs[0].type.dtype dtype = inputs[0].type.dtype
else: else:
dtype = self.dtype dtype = self.dtype
#bcastable = (len(inputs) == 1) # bcastable = (len(inputs) == 1)
bcastable = False bcastable = False
otype = T.TensorType( otype = T.TensorType(broadcastable=(bcastable,), dtype=dtype)
broadcastable=(bcastable,),
dtype=dtype)
return T.Apply(self, inputs, [otype()]) return T.Apply(self, inputs, [otype()])
def __str__(self): def __str__(self):
...@@ -700,13 +693,14 @@ class MakeVectorPrinter: ...@@ -700,13 +693,14 @@ class MakeVectorPrinter:
if r.owner is None: if r.owner is None:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector): elif isinstance(r.owner.op, MakeVector):
return "[%s]" % ", ".join(pstate.pprinter.process( return "[%s]" % ", ".join(
input, pstate.clone(precedence=1000)) for input pstate.pprinter.process(input, pstate.clone(precedence=1000))
in r.owner.inputs) for input in r.owner.inputs)
else: else:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
T.pprint.assign(lambda pstate, r: r.owner and isinstance(
r.owner.op, MakeVector), MakeVectorPrinter()) T.pprint.assign(lambda pstate, r: r.owner and
isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
class ShapeFeature(object): class ShapeFeature(object):
...@@ -843,8 +837,8 @@ class ShapeFeature(object): ...@@ -843,8 +837,8 @@ class ShapeFeature(object):
# by always returning the same object to represent 1 # by always returning the same object to represent 1
return self.lscalar_one return self.lscalar_one
if (type(s_i) in integer_types or if (type(s_i) in integer_types or
isinstance(s_i, numpy.integer) or isinstance(s_i, numpy.integer) or
(isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)): (isinstance(s_i, numpy.ndarray) and s_i.ndim == 0)):
# this shape is a constant # this shape is a constant
assert s_i >= 0 assert s_i >= 0
return T.constant(s_i, dtype='int64') return T.constant(s_i, dtype='int64')
...@@ -859,9 +853,9 @@ class ShapeFeature(object): ...@@ -859,9 +853,9 @@ class ShapeFeature(object):
# s_i is x.shape[i], we change it to Shape_i. # s_i is x.shape[i], we change it to Shape_i.
if (s_i.owner and if (s_i.owner and
isinstance(s_i.owner.op, Subtensor) and isinstance(s_i.owner.op, Subtensor) and
s_i.owner.inputs[0].owner and s_i.owner.inputs[0].owner and
isinstance(s_i.owner.inputs[0].owner.op, T.Shape)): isinstance(s_i.owner.inputs[0].owner.op, T.Shape)):
assert s_i.ndim == 0 assert s_i.ndim == 0
assert len(s_i.owner.op.idx_list) == 1 assert len(s_i.owner.op.idx_list) == 1
...@@ -883,7 +877,7 @@ class ShapeFeature(object): ...@@ -883,7 +877,7 @@ class ShapeFeature(object):
return s_i return s_i
else: else:
raise TypeError('Unsupported shape element', raise TypeError('Unsupported shape element',
s_i, type(s_i), getattr(s_i, 'type', None)) s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s): def set_shape(self, r, s):
"""Assign the shape `s` to previously un-shaped variable `r`. """Assign the shape `s` to previously un-shaped variable `r`.
...@@ -910,7 +904,7 @@ class ShapeFeature(object): ...@@ -910,7 +904,7 @@ class ShapeFeature(object):
shape_vars = [] shape_vars = []
for i in xrange(r.ndim): for i in xrange(r.ndim):
if (hasattr(r.type, 'broadcastable') and if (hasattr(r.type, 'broadcastable') and
r.type.broadcastable[i]): r.type.broadcastable[i]):
shape_vars.append(self.lscalar_one) shape_vars.append(self.lscalar_one)
else: else:
shape_vars.append(self.unpack(s[i])) shape_vars.append(self.unpack(s[i]))
...@@ -947,8 +941,8 @@ class ShapeFeature(object): ...@@ -947,8 +941,8 @@ class ShapeFeature(object):
self.set_shape(r, other_shape) self.set_shape(r, other_shape)
return return
if (other_r.owner and r.owner and if (other_r.owner and r.owner and
other_r.owner.inputs == r.owner.inputs and other_r.owner.inputs == r.owner.inputs and
other_r.owner.op == r.owner.op): other_r.owner.op == r.owner.op):
# We are doing a merge. So the 2 shapes graph will be the # We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call # same. This is only a speed optimization to call
# ancestors() less frequently. # ancestors() less frequently.
...@@ -957,10 +951,10 @@ class ShapeFeature(object): ...@@ -957,10 +951,10 @@ class ShapeFeature(object):
# Merge other_shape with r_shape, giving the priority to other_shape # Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = [] merged_shape = []
for i, ps in enumerate(other_shape): for i, ps in enumerate(other_shape):
if (ps.owner if (ps.owner and
and isinstance(getattr(ps.owner, 'op', None), Shape_i) isinstance(getattr(ps.owner, 'op', None), Shape_i) and
and ps.owner.op.i == i ps.owner.op.i == i and
and ps.owner.inputs[0] in (r, other_r)): ps.owner.inputs[0] in (r, other_r)):
# If other_shape[i] is uninformative, use r_shape[i]. # If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]: # For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r); # - Shape_i(i)(other_r);
...@@ -1084,11 +1078,11 @@ class ShapeFeature(object): ...@@ -1084,11 +1078,11 @@ class ShapeFeature(object):
r in node.inputs]) r in node.inputs])
except NotImplementedError as e: except NotImplementedError as e:
raise NotImplementedError( raise NotImplementedError(
'Code called by infer_shape failed raising a ' 'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to ' 'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer ' 'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError ' 'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e) 'instead. The original exception message is: %s' % e)
except Exception as e: except Exception as e:
msg = ('Failed to infer_shape from Op %s.\nInput shapes: ' msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
'%s\nException encountered during infer_shape: ' '%s\nException encountered during infer_shape: '
...@@ -1108,10 +1102,10 @@ class ShapeFeature(object): ...@@ -1108,10 +1102,10 @@ class ShapeFeature(object):
if len(o_shapes) != len(node.outputs): if len(o_shapes) != len(node.outputs):
raise Exception( raise Exception(
('The infer_shape method for the Op "%s" returned a list ' + ('The infer_shape method for the Op "%s" returned a list ' +
'with the wrong number of element: len(o_shapes) = %d ' + 'with the wrong number of element: len(o_shapes) = %d ' +
' != len(node.outputs) = %d') % (str(node.op), ' != len(node.outputs) = %d') % (str(node.op),
len(o_shapes), len(o_shapes),
len(node.outputs))) len(node.outputs)))
# Ensure shapes are in 'int64'. This is to make sure the assert # Ensure shapes are in 'int64'. This is to make sure the assert
# found in the `local_useless_subtensor` optimization does not fail. # found in the `local_useless_subtensor` optimization does not fail.
...@@ -1173,9 +1167,9 @@ class ShapeFeature(object): ...@@ -1173,9 +1167,9 @@ class ShapeFeature(object):
# with the InputToGpuOptimizer optimizer. # with the InputToGpuOptimizer optimizer.
continue continue
if (repl.owner and if (repl.owner and
repl.owner.inputs[0] is shpnode.inputs[0] and repl.owner.inputs[0] is shpnode.inputs[0] and
isinstance(repl.owner.op, Shape_i) and isinstance(repl.owner.op, Shape_i) and
repl.owner.op.i == shpnode.op.i): repl.owner.op.i == shpnode.op.i):
# The replacement is a shape_i of the same # The replacement is a shape_i of the same
# input. So no need to do this equivalent # input. So no need to do this equivalent
# replacement. # replacement.
...@@ -1239,7 +1233,7 @@ class ShapeFeature(object): ...@@ -1239,7 +1233,7 @@ class ShapeFeature(object):
if not dx.owner or not dy.owner: if not dx.owner or not dy.owner:
return False return False
if (not isinstance(dx.owner.op, Shape_i) or if (not isinstance(dx.owner.op, Shape_i) or
not isinstance(dy.owner.op, Shape_i)): not isinstance(dy.owner.op, Shape_i)):
return False return False
opx = dx.owner.op opx = dx.owner.op
opy = dy.owner.op opy = dy.owner.op
...@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node): ...@@ -1310,10 +1304,9 @@ def local_fill_to_alloc(node):
return return
# TODO: cut out un-necessary dimshuffles of v # TODO: cut out un-necessary dimshuffles of v
assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type, assert rval[0].type == node.outputs[0].type, (
'orig', node.outputs[0].type, 'rval', rval[0].type, 'orig', node.outputs[0].type, 'node',
'node', node, node,) # theano.printing.debugprint(node.outputs[0], file='str'))
) # theano.printing.debugprint(node.outputs[0], file='str'))
return rval return rval
...@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node): ...@@ -1404,7 +1397,7 @@ def local_subtensor_make_vector(node):
try: try:
idx, = node.op.idx_list idx, = node.op.idx_list
except Exception: except Exception:
#'how can you have multiple indexes into a shape?' # 'how can you have multiple indexes into a shape?'
raise raise
if isinstance(idx, (scalar.Scalar, T.TensorType)): if isinstance(idx, (scalar.Scalar, T.TensorType)):
...@@ -1467,13 +1460,13 @@ def local_useless_elemwise(node): ...@@ -1467,13 +1460,13 @@ def local_useless_elemwise(node):
if isinstance(node.op, T.Elemwise): if isinstance(node.op, T.Elemwise):
if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2: if node.op.scalar_op == theano.scalar.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true # it is the same var in the graph. That will always be true
return [T.fill(node.inputs[0], return [T.fill(node.inputs[0],
T.constant(1.0, T.constant(1.0,
dtype=node.outputs[0].type.dtype))] dtype=node.outputs[0].type.dtype))]
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
return [T.fill(node.inputs[0], return [T.fill(node.inputs[0],
T.constant(0.0, T.constant(0.0,
dtype=node.outputs[0].type.dtype))] dtype=node.outputs[0].type.dtype))]
...@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node): ...@@ -1482,8 +1475,8 @@ def local_useless_elemwise(node):
elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
return [node.inputs[0]] return [node.inputs[0]]
elif (node.op.scalar_op == theano.scalar.identity elif (node.op.scalar_op == theano.scalar.identity and
and len(node.inputs) == 1): len(node.inputs) == 1):
return [node.inputs[0]] return [node.inputs[0]]
...@@ -1513,12 +1506,12 @@ def local_cast_cast(node): ...@@ -1513,12 +1506,12 @@ def local_cast_cast(node):
and the first cast cause an upcast. and the first cast cause an upcast.
""" """
if (not isinstance(node.op, T.Elemwise) or if (not isinstance(node.op, T.Elemwise) or
not isinstance(node.op.scalar_op, scalar.Cast)): not isinstance(node.op.scalar_op, scalar.Cast)):
return return
x = node.inputs[0] x = node.inputs[0]
if (not x.owner or if (not x.owner or
not isinstance(x.owner.op, T.Elemwise) or not isinstance(x.owner.op, T.Elemwise) or
not isinstance(x.owner.op.scalar_op, scalar.Cast)): not isinstance(x.owner.op.scalar_op, scalar.Cast)):
return return
if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type: if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type:
return [x] return [x]
...@@ -1738,7 +1731,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1738,7 +1731,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# The broadcast pattern of the output must match the broadcast # The broadcast pattern of the output must match the broadcast
# pattern of at least one of the inputs. # pattern of at least one of the inputs.
if not any([i.type.broadcastable == if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
...@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1749,10 +1742,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# At least one input must have an owner that is either a AllocOP or a # At least one input must have an owner that is either a AllocOP or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any([i.owner if not any([i.owner and (isinstance(i.owner.op, AllocOP) or
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i)) for i in node.inputs]):
dimshuffled_alloc(i))
for i in node.inputs]):
return False return False
# Search for input that we can use as a baseline for the dimensions. # Search for input that we can use as a baseline for the dimensions.
...@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1761,9 +1752,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a # Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# AllocOP so that all allocs can be optimized. # AllocOP so that all allocs can be optimized.
if not (i.owner if not (i.owner and (isinstance(i.owner.op, AllocOP) or
and (isinstance(i.owner.op, AllocOP) dimshuffled_alloc(i))):
or dimshuffled_alloc(i))):
assert_op_idx = idx assert_op_idx = idx
break break
...@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1773,8 +1763,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# there is more than one then do all but one. number of # there is more than one then do all but one. number of
# inputs with alloc or dimshuffle alloc # inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, AllocOP) if (i.owner and (isinstance(i.owner.op, AllocOP) or
or dimshuffled_alloc(i)))] dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we # If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed. # will use for the shape. So no alloc would be removed.
if len(l2) > 1: if len(l2) > 1:
...@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1794,14 +1784,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
same_shape = node.fgraph.shape_feature.same_shape same_shape = node.fgraph.shape_feature.same_shape
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP) if (i.owner and isinstance(i.owner.op, AllocOP) and
and i.owner.inputs[0].type != i.owner.outputs[0].type): i.owner.inputs[0].type != i.owner.outputs[0].type):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we # when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later # will remove that alloc later
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert and
and not same_shape(i, cmp_op)): not same_shape(i, cmp_op)):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx]) *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
...@@ -1891,7 +1880,7 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1891,7 +1880,7 @@ def local_upcast_elemwise_constant_inputs(node):
scalar_op = node.op.scalar_op scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference # print "aa", scalar_op.output_types_preference
if (getattr(scalar_op, 'output_types_preference', None) if (getattr(scalar_op, 'output_types_preference', None)
in (T.scal.upgrade_to_float, T.scal.upcast_out)): in (T.scal.upgrade_to_float, T.scal.upcast_out)):
# this is the kind of op that we can screw with the input # this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly # dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype output_dtype = node.outputs[0].type.dtype
...@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1909,12 +1898,12 @@ def local_upcast_elemwise_constant_inputs(node):
i.ndim)) i.ndim))
else: else:
if shape_i is None: if shape_i is None:
return return new_inputs.append(
new_inputs.append(T.alloc(T.cast(cval_i, T.alloc(T.cast(cval_i, output_dtype),
output_dtype), *[shape_i(d)(i)
*[shape_i(d)(i) for d in xrange(i.ndim)])) for d in xrange(i.ndim)]))
#print >> sys.stderr, "AAA", # print >> sys.stderr, "AAA",
#*[Shape_i(d)(i) for d in xrange(i.ndim)] # *[Shape_i(d)(i) for d in xrange(i.ndim)]
except NotScalarConstantError: except NotScalarConstantError:
# for the case of a non-scalar # for the case of a non-scalar
if isinstance(i, T.TensorConstant): if isinstance(i, T.TensorConstant):
...@@ -1958,7 +1947,7 @@ def local_useless_inc_subtensor(node): ...@@ -1958,7 +1947,7 @@ def local_useless_inc_subtensor(node):
except NotScalarConstantError: except NotScalarConstantError:
return return
if (node.inputs[0].ndim != node.inputs[1].ndim or if (node.inputs[0].ndim != node.inputs[1].ndim or
node.inputs[0].broadcastable != node.inputs[1].broadcastable): node.inputs[0].broadcastable != node.inputs[1].broadcastable):
# FB: I didn't check if this case can happen, but this opt # FB: I didn't check if this case can happen, but this opt
# doesn't support it. # doesn't support it.
return return
...@@ -1994,16 +1983,16 @@ def local_set_to_inc_subtensor(node): ...@@ -1994,16 +1983,16 @@ def local_set_to_inc_subtensor(node):
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
""" """
if (isinstance(node.op, AdvancedIncSubtensor1) and if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc == True and node.op.set_instead_of_inc and
node.inputs[1].owner and node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)): isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
addn = node.inputs[1].owner addn = node.inputs[1].owner
subn = None subn = None
other = None other = None
if (addn.inputs[0].owner and if (addn.inputs[0].owner and
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)): isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[0].owner subn = addn.inputs[0].owner
other = addn.inputs[1] other = addn.inputs[1]
elif (addn.inputs[1].owner and elif (addn.inputs[1].owner and
...@@ -2013,7 +2002,7 @@ def local_set_to_inc_subtensor(node): ...@@ -2013,7 +2002,7 @@ def local_set_to_inc_subtensor(node):
else: else:
return return
if (subn.inputs[1] != node.inputs[2] or if (subn.inputs[1] != node.inputs[2] or
subn.inputs[0] != node.inputs[0]): subn.inputs[0] != node.inputs[0]):
return return
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])] return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
...@@ -2030,9 +2019,9 @@ def local_useless_slice(node): ...@@ -2030,9 +2019,9 @@ def local_useless_slice(node):
last_slice = len(slices) last_slice = len(slices)
for s in slices[::-1]: for s in slices[::-1]:
# check if slice and then check slice indices # check if slice and then check slice indices
if (isinstance(s, slice) and s.start is None and s.stop is None if (isinstance(s, slice) and s.start is None and s.stop is None and
and (s.step is None or T.extract_constant(s.step) == 1)): (s.step is None or T.extract_constant(s.step) == 1)):
last_slice -= 1 last_slice -= 1
else: else:
break break
# check if we removed something # check if we removed something
...@@ -2098,11 +2087,10 @@ def local_useless_subtensor(node): ...@@ -2098,11 +2087,10 @@ def local_useless_subtensor(node):
# the same underlying variable. # the same underlying variable.
if (length_pos_shape_i.owner and if (length_pos_shape_i.owner and
isinstance(length_pos_shape_i.owner.op, isinstance(length_pos_shape_i.owner.op,
T.ScalarFromTensor)): T.ScalarFromTensor)):
length_pos_shape_i = length_pos_shape_i.owner.inputs[0] length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
elif (length_pos.owner and elif (length_pos.owner and
isinstance(length_pos.owner.op, isinstance(length_pos.owner.op, T.TensorFromScalar)):
T.TensorFromScalar)):
length_pos = length_pos.owner.inputs[0] length_pos = length_pos.owner.inputs[0]
else: else:
# We did not find underlying variables of the same type # We did not find underlying variables of the same type
...@@ -2322,8 +2310,8 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -2322,8 +2310,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pn_stop = sl1.start + (sl2.start - 1) * sl1.step pn_stop = sl1.start + (sl2.start - 1) * sl1.step
pn_stop = T.switch(T.and_(T.lt(pn_stop, 0), pn_stop = T.switch(T.and_(T.lt(pn_stop, 0),
T.gt(flen, 0)), T.gt(flen, 0)),
-len1 - 1, -len1 - 1,
T.minimum(pn_stop, sl1.stop)) T.minimum(pn_stop, sl1.stop))
pn_start = sl1.start + (sl2.stop - 1) * sl1.step pn_start = sl1.start + (sl2.stop - 1) * sl1.step
pn_start = T.minimum(pn_start, sl1.stop) pn_start = T.minimum(pn_start, sl1.stop)
pn_start = T.maximum(pn_start, 0) pn_start = T.maximum(pn_start, 0)
...@@ -2345,9 +2333,8 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -2345,9 +2333,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
pp_start)) pp_start))
stop = T.switch(T.lt(reverse2 * reverse1, 0), stop = T.switch(T.lt(reverse2 * reverse1, 0),
T.switch(T.lt(reverse1, 0), np_stop, pn_stop), T.switch(T.lt(reverse1, 0), np_stop, pn_stop),
T.switch(T.lt(reverse1, 0), nn_stop, pp_stop T.switch(T.lt(reverse1, 0), nn_stop, pp_stop))
))
step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step) step = T.switch(T.lt(reverse2 * reverse1, 0), n_step, p_step)
start = T.switch(T.le(flen, 0), 0, start) start = T.switch(T.le(flen, 0), 0, start)
...@@ -2463,7 +2450,7 @@ def local_subtensor_of_alloc(node): ...@@ -2463,7 +2450,7 @@ def local_subtensor_of_alloc(node):
# We check that the corresponding val dimensions was # We check that the corresponding val dimensions was
# not a broadcasted dimensions. # not a broadcasted dimensions.
if (val.type.ndim > (i - n_added_dims) and if (val.type.ndim > (i - n_added_dims) and
val.type.broadcastable[i - n_added_dims]): val.type.broadcastable[i - n_added_dims]):
val_slices.append(slice(None)) val_slices.append(slice(None))
else: else:
val_slices.append(sl) val_slices.append(sl)
...@@ -2496,8 +2483,8 @@ def local_subtensor_of_alloc(node): ...@@ -2496,8 +2483,8 @@ def local_subtensor_of_alloc(node):
rval[0] = theano.tensor.unbroadcast( rval[0] = theano.tensor.unbroadcast(
rval[0], rval[0],
*[i for i, (b1, b2) in enumerate(zip(rval[0].broadcastable, *[i for i, (b1, b2) in enumerate(zip(rval[0].broadcastable,
node.outputs[0].broadcastable)) node.outputs[0].broadcastable))
if b1 and not b2]) if b1 and not b2])
return rval return rval
...@@ -2518,7 +2505,7 @@ def local_subtensor_of_dot(node): ...@@ -2518,7 +2505,7 @@ def local_subtensor_of_dot(node):
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
return return
if (not node.inputs[0].owner or if (not node.inputs[0].owner or
not isinstance(node.inputs[0].owner.op, T.Dot)): not isinstance(node.inputs[0].owner.op, T.Dot)):
return return
# If another node uses the outputs of the dot, # If another node uses the outputs of the dot,
# we don't want to compute the sub part twice. # we don't want to compute the sub part twice.
...@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node): ...@@ -2540,7 +2527,8 @@ def local_subtensor_of_dot(node):
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case) # (dot also handles b.ndim < 2 as a special case)
if b.ndim > 1 and len(b_indices) >= b.ndim - 1: if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
b_indices = b_indices[:b.ndim-2] + (slice(None, None, None),) + b_indices[b.ndim-2:] b_indices = (b_indices[:b.ndim - 2] +
(slice(None, None, None),) + b_indices[b.ndim - 2:])
a_sub = a.__getitem__(tuple(a_indices)) a_sub = a.__getitem__(tuple(a_indices))
b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
...@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node): ...@@ -2583,14 +2571,13 @@ def local_IncSubtensor_serialize(node):
""" """
def movable(i): def movable(i):
# Return True iff this is a incsubtensor that we can move # Return True iff this is a incsubtensor that we can move
return i.owner \ return (i.owner and
and isinstance(i.owner.op, (IncSubtensor, isinstance(i.owner.op, (IncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,)) and
)) \ i.type == o_type and
and i.type == o_type \ len(i.clients) == 1 and
and len(i.clients) == 1 \ not i.owner.op.set_instead_of_inc)
and not i.owner.op.set_instead_of_inc
if node.op == T.add: if node.op == T.add:
o_type = node.outputs[0].type o_type = node.outputs[0].type
...@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node): ...@@ -2598,8 +2585,8 @@ def local_IncSubtensor_serialize(node):
movable_inputs = [i for i in node.inputs if movable(i)] movable_inputs = [i for i in node.inputs if movable(i)]
if movable_inputs: if movable_inputs:
new_inputs = [i for i in node.inputs if not movable(i)] \ new_inputs = ([i for i in node.inputs if not movable(i)] +
+ [mi.owner.inputs[0] for mi in movable_inputs] [mi.owner.inputs[0] for mi in movable_inputs])
new_add = T.add(*new_inputs) new_add = T.add(*new_inputs)
# stack up the new incsubtensors # stack up the new incsubtensors
...@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node): ...@@ -2638,9 +2625,10 @@ def local_inplace_setsubtensor(node):
return [new_node] return [new_node]
return False return False
compile.optdb.register('local_inplace_setsubtensor', compile.optdb.register('local_inplace_setsubtensor',
TopoOptimizer(local_inplace_setsubtensor, TopoOptimizer(
failure_callback=TopoOptimizer.warn_inplace), 60, local_inplace_setsubtensor,
'fast_run', 'inplace') # DEBUG failure_callback=TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([AdvancedIncSubtensor1], inplace=True) @gof.local_optimizer([AdvancedIncSubtensor1], inplace=True)
...@@ -2653,8 +2641,8 @@ def local_inplace_incsubtensor1(node): ...@@ -2653,8 +2641,8 @@ def local_inplace_incsubtensor1(node):
return False return False
compile.optdb.register('local_inplace_incsubtensor1', compile.optdb.register('local_inplace_incsubtensor1',
TopoOptimizer( TopoOptimizer(
local_inplace_incsubtensor1, local_inplace_incsubtensor1,
failure_callback=TopoOptimizer.warn_inplace), failure_callback=TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace') # DEBUG 60, 'fast_run', 'inplace') # DEBUG
...@@ -2671,7 +2659,7 @@ def local_incsubtensor_of_zeros(node): ...@@ -2671,7 +2659,7 @@ def local_incsubtensor_of_zeros(node):
if (isinstance(node.op, (IncSubtensor, if (isinstance(node.op, (IncSubtensor,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1)) and AdvancedIncSubtensor1)) and
not node.op.set_instead_of_inc): not node.op.set_instead_of_inc):
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
...@@ -2713,8 +2701,8 @@ def local_setsubtensor_of_constants(node): ...@@ -2713,8 +2701,8 @@ def local_setsubtensor_of_constants(node):
pass pass
if (replace_x is not None and if (replace_x is not None and
replace_y is not None and replace_y is not None and
replace_x == replace_y): replace_x == replace_y):
return [x] return [x]
else: else:
return False return False
...@@ -2738,7 +2726,7 @@ def local_adv_sub1_adv_inc_sub1(node): ...@@ -2738,7 +2726,7 @@ def local_adv_sub1_adv_inc_sub1(node):
return return
inp = node.inputs[0] inp = node.inputs[0]
if (not inp.owner or if (not inp.owner or
not isinstance(inp.owner.op, AdvancedIncSubtensor1)): not isinstance(inp.owner.op, AdvancedIncSubtensor1)):
return return
idx = node.inputs[1] idx = node.inputs[1]
idx2 = inp.owner.inputs[2] idx2 = inp.owner.inputs[2]
...@@ -2747,13 +2735,13 @@ def local_adv_sub1_adv_inc_sub1(node): ...@@ -2747,13 +2735,13 @@ def local_adv_sub1_adv_inc_sub1(node):
if idx is not idx2: if idx is not idx2:
return return
if (not inp.owner.op.set_instead_of_inc and if (not inp.owner.op.set_instead_of_inc and
T.extract_constant(x) != 0): T.extract_constant(x) != 0):
return return
cond = [T.all(T.and_(T.lt(idx, x.shape[0]), cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
T.ge(idx, -x.shape[0])))]
if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0): if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
cond.append(T.eq(idx.shape[0], y.shape[0])) cond.append(T.eq(idx.shape[0], y.shape[0]))
y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 that was optimized away")(y, *cond) y = Assert("Bad indexing or shapes in a AdvancedIncSubtensor1 "
"that was optimized away")(y, *cond)
if y.dtype == node.outputs[0].dtype: if y.dtype == node.outputs[0].dtype:
return [y] return [y]
...@@ -2828,33 +2816,34 @@ def local_useless_inc_subtensor_alloc(node): ...@@ -2828,33 +2816,34 @@ def local_useless_inc_subtensor_alloc(node):
# Build `z_broad` explicitly to include extra implicit dimensions. # Build `z_broad` explicitly to include extra implicit dimensions.
z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable) z_broad = ((True,) * (xi.ndim - z.ndim) + z.broadcastable)
cond = [# The shapes of `y` and `xi` must either agree or `y` may cond = [
# also have shape equal to 1 which may be treated as a # The shapes of `y` and `xi` must either agree or `y` may
# broadcastable dimension by the subtensor op. # also have shape equal to 1 which may be treated as a
T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k])) # broadcastable dimension by the subtensor op.
# Loop over all dimensions. T.or_(T.eq(y.shape[k], 1), T.eq(y.shape[k], xi.shape[k]))
for k in xrange(xi.ndim) # Loop over all dimensions.
# We need to check the above shapes, if for k in xrange(xi.ndim)
# * the pre-alloc increment `z` is broadcastable in # We need to check the above shapes, if
# dimension `k` (if it isn't, then the shapes of `z` and # * the pre-alloc increment `z` is broadcastable in
# `y` are the same by the definition of the `Alloc` op in # dimension `k` (if it isn't, then the shapes of `z` and
# this dimension and replacing `y` by `z` will not hide a # `y` are the same by the definition of the `Alloc` op in
# shape error), and # this dimension and replacing `y` by `z` will not hide a
# * `xi` and `y` do not have the same shape in dimension # shape error), and
# `k` or we cannot infer the shape statically (if the # * `xi` and `y` do not have the same shape in dimension
# shapes of `xi` and `y` are not the same, then replacing # `k` or we cannot infer the shape statically (if the
# `y` by `z` will hide the shape error of `y`), and # shapes of `xi` and `y` are not the same, then replacing
# * the shape of `y` is not equal to 1 or we cannot infer # `y` by `z` will hide the shape error of `y`), and
# the shape statically (if the shape of `y` is equal to # * the shape of `y` is not equal to 1 or we cannot infer
# 1, then `y` is broadcasted by the inc_subtensor op # the shape statically (if the shape of `y` is equal to
# internally, so the shapes of `xi` and `y` do not need # 1, then `y` is broadcasted by the inc_subtensor op
# to match in dimension `k`; else we need to check at # internally, so the shapes of `xi` and `y` do not need
# runtime that the shape of `y` is either 1 or the same # to match in dimension `k`; else we need to check at
# as `xi` or otherwise replacing `y` by `z` will hide a # runtime that the shape of `y` is either 1 or the same
# shape error). # as `xi` or otherwise replacing `y` by `z` will hide a
if (z_broad[k] and # shape error).
not same_shape(xi, y, dim_x=k, dim_y=k) and if (z_broad[k] and
shape_of[y][k] != 1)] not same_shape(xi, y, dim_x=k, dim_y=k) and
shape_of[y][k] != 1)]
if len(cond) > 0: if len(cond) > 0:
msg = '`x[i]` and `y` do not have the same shape.' msg = '`x[i]` and `y` do not have the same shape.'
...@@ -2916,7 +2905,7 @@ def local_rebroadcast_lift(node): ...@@ -2916,7 +2905,7 @@ def local_rebroadcast_lift(node):
# compilation phase. # compilation phase.
if hasattr(input, 'clients') and len(input.clients) == 1: if hasattr(input, 'clients') and len(input.clients) == 1:
rval = inode.op.make_node(T.Rebroadcast(*list(op.axis.items()))( rval = inode.op.make_node(T.Rebroadcast(*list(op.axis.items()))(
inode.inputs[0])).outputs inode.inputs[0])).outputs
return rval return rval
if inode and isinstance(inode.op, T.Rebroadcast): if inode and isinstance(inode.op, T.Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides # the "axis" specification in the outer Rebroadcast overrides
...@@ -3031,11 +3020,11 @@ def local_join_make_vector(node): ...@@ -3031,11 +3020,11 @@ def local_join_make_vector(node):
for idx in xrange(2, len(node.inputs)): for idx in xrange(2, len(node.inputs)):
inp = node.inputs[idx] inp = node.inputs[idx]
if (inp.owner and if (inp.owner and
isinstance(inp.owner.op, MakeVector) and isinstance(inp.owner.op, MakeVector) and
new_inputs[-1].owner and new_inputs[-1].owner and
isinstance(new_inputs[-1].owner.op, MakeVector) and isinstance(new_inputs[-1].owner.op, MakeVector) and
# MakeVector have a dtype parameter # MakeVector have a dtype parameter
inp.owner.op == new_inputs[-1].owner.op): inp.owner.op == new_inputs[-1].owner.op):
inps = new_inputs[-1].owner.inputs + inp.owner.inputs inps = new_inputs[-1].owner.inputs + inp.owner.inputs
new_inputs[-1] = inp.owner.op(*inps) new_inputs[-1] = inp.owner.op(*inps)
else: else:
...@@ -3059,7 +3048,7 @@ def local_remove_switch_const_cond(node): ...@@ -3059,7 +3048,7 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False) cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0: if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0: if cond == 0:
...@@ -3241,9 +3230,9 @@ def local_flatten_lift(node): ...@@ -3241,9 +3230,9 @@ def local_flatten_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten. nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
""" """
if (isinstance(node.op, T.Flatten) and if (isinstance(node.op, T.Flatten) and
node.inputs[0].owner and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1): len(node.inputs[0].owner.inputs) == 1):
f = node.op(node.inputs[0].owner.inputs[0]) f = node.op(node.inputs[0].owner.inputs[0])
e = node.inputs[0].owner.op(f) e = node.inputs[0].owner.op(f)
return [e] return [e]
...@@ -3290,9 +3279,9 @@ def local_reshape_lift(node): ...@@ -3290,9 +3279,9 @@ def local_reshape_lift(node):
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape. nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape.
""" """
if (isinstance(node.op, T.Reshape) and if (isinstance(node.op, T.Reshape) and
node.inputs[0].owner and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Elemwise) and isinstance(node.inputs[0].owner.op, T.Elemwise) and
len(node.inputs[0].owner.inputs) == 1): len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
e = node.inputs[0].owner.op(r) e = node.inputs[0].owner.op(r)
# In rare case the original broadcast was (False, True), but # In rare case the original broadcast was (False, True), but
...@@ -3539,7 +3528,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3539,7 +3528,7 @@ class Canonizer(gof.LocalOptimizer):
return [input], [] return [input], []
if input.owner is None or input.owner.op not in [ if input.owner is None or input.owner.op not in [
self.main, self.inverse, self.reciprocal]: self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle): if input.owner and isinstance(input.owner.op, T.DimShuffle):
# If input is a DimShuffle of some input which does # If input is a DimShuffle of some input which does
# something like this: # something like this:
...@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3552,9 +3541,9 @@ class Canonizer(gof.LocalOptimizer):
# the num/denum of its input # the num/denum of its input
dsn = input.owner # dimshuffle node dsn = input.owner # dimshuffle node
dsop = dsn.op # dimshuffle op dsop = dsn.op # dimshuffle op
dsi0 = dsn.inputs[0] # the first input of the
# dimshuffle i.e. the ndarray to # the first input of the dimshuffle i.e. the ndarray to redim
# redim dsi0 = dsn.inputs[0]
# The compatible order is a DimShuffle "new_order" of the form: # The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim) # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
...@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3566,9 +3555,9 @@ class Canonizer(gof.LocalOptimizer):
# different numbers of dimensions (hence why we can # different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it # discard its information - we know we can retrieve it
# later on). # later on).
compatible_order = ('x',) * (input.type.ndim compatible_order = (('x',) *
- dsi0.type.ndim) + tuple( (input.type.ndim - dsi0.type.ndim) +
range(dsi0.type.ndim)) tuple(range(dsi0.type.ndim)))
if dsop.new_order == compatible_order: if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize, # If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input. # we return the num_denum of the dimshuffled input.
...@@ -3815,9 +3804,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3815,9 +3804,9 @@ class Canonizer(gof.LocalOptimizer):
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.type.dtype != out.type.dtype: if new.type.dtype != out.type.dtype:
#new = T.fill(out, new) # new = T.fill(out, new)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out( elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(
getattr(scalar, out.type.dtype)))) getattr(scalar, out.type.dtype))))
new = elem_op(new) new = elem_op(new)
assert (new.type == out.type) == (not (new.type != out.type)) assert (new.type == out.type) == (not (new.type != out.type))
...@@ -3833,12 +3822,12 @@ class Canonizer(gof.LocalOptimizer): ...@@ -3833,12 +3822,12 @@ class Canonizer(gof.LocalOptimizer):
else: else:
_logger.warning(' '.join(('CANONIZE FAILED: new, out = ', _logger.warning(' '.join(('CANONIZE FAILED: new, out = ',
new, ',', out, 'types', new, ',', out, 'types',
new.type, ',', out.type))) new.type, ',', out.type)))
return False return False
def __str__(self): def __str__(self):
return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % ( return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (
self.main, self.inverse, self.reciprocal)) self.main, self.inverse, self.reciprocal))
def mul_calculate(num, denum, aslist=False, out_type=None): def mul_calculate(num, denum, aslist=False, out_type=None):
...@@ -3872,7 +3861,7 @@ register_canonicalize(local_mul_canonizer, name='local_mul_canonizer') ...@@ -3872,7 +3861,7 @@ register_canonicalize(local_mul_canonizer, name='local_mul_canonizer')
def local_neg_to_mul(node): def local_neg_to_mul(node):
if node.op == T.neg: if node.op == T.neg:
return [T.mul(numpy.array(-1, dtype=node.inputs[0].dtype), return [T.mul(numpy.array(-1, dtype=node.inputs[0].dtype),
node.inputs[0])] node.inputs[0])]
register_canonicalize(local_neg_to_mul) register_canonicalize(local_neg_to_mul)
...@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node): ...@@ -3924,10 +3913,10 @@ def local_elemwise_sub_zeros(node):
""" """
Elemwise{sub}(X,X) -> zeros_like(X) Elemwise{sub}(X,X) -> zeros_like(X)
""" """
if (isinstance(node.op, T.Elemwise) if (isinstance(node.op, T.Elemwise) and
and node.op.scalar_op.nin == 2 node.op.scalar_op.nin == 2 and
and node.op.scalar_op == scalar.sub node.op.scalar_op == scalar.sub and
and node.inputs[0] == node.inputs[1]): node.inputs[0] == node.inputs[1]):
return [T.zeros_like(node.inputs[0])] return [T.zeros_like(node.inputs[0])]
...@@ -4013,9 +4002,8 @@ def local_sum_div_dimshuffle(node): ...@@ -4013,9 +4002,8 @@ def local_sum_div_dimshuffle(node):
' to False.') ' to False.')
new_denom = T.DimShuffle( new_denom = T.DimShuffle(
thing_dimshuffled.type.broadcastable, thing_dimshuffled.type.broadcastable,
new_new_order new_new_order)(thing_dimshuffled)
)(thing_dimshuffled)
return [T.true_div(node.op(numerator), new_denom)] return [T.true_div(node.op(numerator), new_denom)]
# else: # else:
# print 'incompatible dims:', axis, new_order # print 'incompatible dims:', axis, new_order
...@@ -4052,8 +4040,9 @@ def local_op_of_op(node): ...@@ -4052,8 +4040,9 @@ def local_op_of_op(node):
# We manipulate the graph so this is done to make sure the opt # We manipulate the graph so this is done to make sure the opt
# doesn't affect other computations. # doesn't affect other computations.
if len(node_inps.clients) == 1: if len(node_inps.clients) == 1:
if (node_inps.owner and (isinstance(node_inps.owner.op, T.elemwise.Prod) if (node_inps.owner and
or isinstance(node_inps.owner.op, T.elemwise.Sum))): (isinstance(node_inps.owner.op, T.elemwise.Prod) or
isinstance(node_inps.owner.op, T.elemwise.Sum))):
# check to see either the inner or outer prod is doing a # check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it # product over all axis, in which case we can remove it
...@@ -4074,7 +4063,6 @@ def local_op_of_op(node): ...@@ -4074,7 +4063,6 @@ def local_op_of_op(node):
assert len(newaxis) == len(list(node_inps.owner.op.axis) + assert len(newaxis) == len(list(node_inps.owner.op.axis) +
list(node.op.axis)) list(node.op.axis))
# The old bugged logic. We keep it there to generate a warning # The old bugged logic. We keep it there to generate a warning
# when we generated bad code. # when we generated bad code.
alldims = list(range(node_inps.owner.inputs[0].type.ndim)) alldims = list(range(node_inps.owner.inputs[0].type.ndim))
...@@ -4087,20 +4075,20 @@ def local_op_of_op(node): ...@@ -4087,20 +4075,20 @@ def local_op_of_op(node):
if i not in alldims] if i not in alldims]
if (theano.config.warn.sum_sum_bug and if (theano.config.warn.sum_sum_bug and
newaxis != newaxis_old and newaxis != newaxis_old and
len(newaxis) == len(newaxis_old)): len(newaxis) == len(newaxis_old)):
_logger.warn( _logger.warn(
"WARNING (YOUR CURRENT CODE IS FINE): Theano " "WARNING (YOUR CURRENT CODE IS FINE): Theano "
"versions between version 9923a40c7b7a and August " "versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. " "2nd, 2010 generated bugged code in this case. "
"This happens when there are two consecutive sums " "This happens when there are two consecutive sums "
"in the graph and the intermediate sum is not " "in the graph and the intermediate sum is not "
"used elsewhere in the code. Some safeguard " "used elsewhere in the code. Some safeguard "
"removed some bad code, but not in all cases. You " "removed some bad code, but not in all cases. You "
"are in one such case. To disable this warning " "are in one such case. To disable this warning "
"(that you can safely ignore since this bug has " "(that you can safely ignore since this bug has "
"been fixed) set the theano flag " "been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.") "`warn.sum_sum_bug` to False.")
combined = opt_type(newaxis, dtype=out_dtype) combined = opt_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])] return [combined(node_inps.owner.inputs[0])]
...@@ -4126,9 +4114,8 @@ def local_reduce_join(node): ...@@ -4126,9 +4114,8 @@ def local_reduce_join(node):
""" """
if (isinstance(node.op, T.CAReduce) and if (isinstance(node.op, T.CAReduce) and
node.inputs[0].owner and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Join)): isinstance(node.inputs[0].owner.op, T.Join)):
join = node.inputs[0].owner join = node.inputs[0].owner
if T.extract_constant(join.inputs[0]) != 0: if T.extract_constant(join.inputs[0]) != 0:
return return
...@@ -4149,7 +4136,8 @@ def local_reduce_join(node): ...@@ -4149,7 +4136,8 @@ def local_reduce_join(node):
if not inp: if not inp:
return return
if (not isinstance(inp.op, DimShuffle) or if (not isinstance(inp.op, DimShuffle) or
inp.op.new_order != ('x',) + tuple(range(inp.inputs[0].ndim))): inp.op.new_order != ('x',) +
tuple(range(inp.inputs[0].ndim))):
return return
new_inp.append(inp.inputs[0]) new_inp.append(inp.inputs[0])
ret = Elemwise(node.op.scalar_op)(*new_inp) ret = Elemwise(node.op.scalar_op)(*new_inp)
...@@ -4174,9 +4162,8 @@ def local_reduce_join(node): ...@@ -4174,9 +4162,8 @@ def local_reduce_join(node):
'optimization, that modified the pattern ' 'optimization, that modified the pattern '
'"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", ' '"Reduce{scalar.op}(Join(axis=0, a, b), axis=0)", '
'did not check the reduction axis. So if the ' 'did not check the reduction axis. So if the '
'reduction axis was not 0, you got a wrong answer.' 'reduction axis was not 0, you got a wrong answer.'))
)) return
return
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
...@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node): ...@@ -4204,7 +4191,7 @@ def local_cut_useless_reduce(node):
# theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0 # theano/tensor/tests/test_opt.py:T_local_reduce.test_local_reduce_broadcast_some_0
# see gh-790 issue. # see gh-790 issue.
# #
#@register_canonicalize # @register_canonicalize
@register_uncanonicalize @register_uncanonicalize
@register_specialize @register_specialize
@gof.local_optimizer(ALL_REDUCE) @gof.local_optimizer(ALL_REDUCE)
...@@ -4258,7 +4245,7 @@ def local_opt_alloc(node): ...@@ -4258,7 +4245,7 @@ def local_opt_alloc(node):
input = node_inps.owner.inputs[0] input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
...@@ -4346,7 +4333,7 @@ register_canonicalize(local_mul_zero) ...@@ -4346,7 +4333,7 @@ register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div]) @gof.local_optimizer([T.true_div])
def local_div_to_inv(node): def local_div_to_inv(node):
if node.op == T.true_div and N.all( if node.op == T.true_div and N.all(
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0): local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
out = node.outputs[0] out = node.outputs[0]
new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], new_out = T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:],
[])) []))
...@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node): ...@@ -4501,7 +4488,8 @@ def local_pow_specialize_device(node):
if abs(y) > 2: if abs(y) > 2:
# We fuse all the pow together here to make # We fuse all the pow together here to make
# compilation faster # compilation faster
rval1 = Elemwise(theano.scalar.Composite( rval1 = Elemwise(
theano.scalar.Composite(
[pow2_scal[0]], [rval1_scal])).make_node(xsym) [pow2_scal[0]], [rval1_scal])).make_node(xsym)
if y < 0: if y < 0:
rval = [T.inv(rval1)] rval = [T.inv(rval1)]
...@@ -4566,8 +4554,8 @@ def local_mul_specialize(node): ...@@ -4566,8 +4554,8 @@ def local_mul_specialize(node):
else: else:
# The next case would cause a replace by an equivalent case. # The next case would cause a replace by an equivalent case.
if (neg and if (neg and
nb_neg_node == 0 and nb_neg_node == 0 and
nb_cst == 1): nb_cst == 1):
return return
elif neg: elif neg:
# Don't add an extra neg node as we can't # Don't add an extra neg node as we can't
...@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators): ...@@ -4640,8 +4628,8 @@ def check_for_x_over_absX(numerators, denominators):
# TODO: this function should dig/search through dimshuffles # TODO: this function should dig/search through dimshuffles
# This won't catch a dimshuffled absolute value # This won't catch a dimshuffled absolute value
for den in list(denominators): for den in list(denominators):
if (den.owner and den.owner.op == T.abs_ if (den.owner and den.owner.op == T.abs_ and
and den.owner.inputs[0] in numerators): den.owner.inputs[0] in numerators):
if den.owner.inputs[0].type.dtype.startswith('complex'): if den.owner.inputs[0].type.dtype.startswith('complex'):
# TODO: Make an Op that projects a complex number to # TODO: Make an Op that projects a complex number to
# have unit length but projects 0 to 0. That # have unit length but projects 0 to 0. That
...@@ -4715,8 +4703,8 @@ def local_log1p(node): ...@@ -4715,8 +4703,8 @@ def local_log1p(node):
if node.op == T.log: if node.op == T.log:
log_arg, = node.inputs log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add: if log_arg.owner and log_arg.owner.op == T.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = scalarconsts_rest(
scalarconsts_rest(log_arg.owner.inputs) log_arg.owner.inputs)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1): if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts: if not nonconsts:
...@@ -4748,7 +4736,7 @@ def local_log_add(node): ...@@ -4748,7 +4736,7 @@ def local_log_add(node):
if len(zi) != 2: if len(zi) != 2:
# -- upgrading Maximum to handle multiple inputs wasn't trivial # -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO # TODO
#raise NotImplementedError() # raise NotImplementedError()
return return
pre_exp = [x.owner.inputs[0] for x in zi pre_exp = [x.owner.inputs[0] for x in zi
if x.owner and x.owner.op == T.exp] if x.owner and x.owner.op == T.exp]
...@@ -4945,8 +4933,7 @@ def constant_folding(node): ...@@ -4945,8 +4933,7 @@ def constant_folding(node):
storage_map[o] = [None] storage_map[o] = [None]
compute_map[o] = [False] compute_map[o] = [False]
if (hasattr(node.op, 'python_constant_folding') and if (hasattr(node.op, 'python_constant_folding') and
node.op.python_constant_folding(node)): node.op.python_constant_folding(node)):
old_value = getattr(node.op, '_op_use_c_code', False) old_value = getattr(node.op, '_op_use_c_code', False)
try: try:
node.op._op_use_c_code = False node.op._op_use_c_code = False
...@@ -5037,9 +5024,9 @@ register_specialize(local_one_minus_erf) ...@@ -5037,9 +5024,9 @@ register_specialize(local_one_minus_erf)
local_one_minus_erf2 = gof.PatternSub((T.add, local_one_minus_erf2 = gof.PatternSub((T.add,
1, 1,
(T.mul, -1, (T.erf, 'x'))), (T.mul, -1, (T.erf, 'x'))),
(T.erfc, 'x'), (T.erfc, 'x'),
allow_multiple_clients=True, allow_multiple_clients=True,
name='local_one_minus_erf2') name='local_one_minus_erf2')
register_canonicalize(local_one_minus_erf2) register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2) register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2) register_specialize(local_one_minus_erf2)
...@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf) ...@@ -5058,7 +5045,7 @@ register_canonicalize(local_one_plus_neg_erf)
register_stabilize(local_one_plus_neg_erf) register_stabilize(local_one_plus_neg_erf)
register_specialize(local_one_plus_neg_erf) register_specialize(local_one_plus_neg_erf)
#(-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize # (-1)+erf(x) => -erfc(x) don't need erf(x)+(-1) as the canonicalize
# will put the -1 as the first argument. # will put the -1 as the first argument.
local_erf_minus_one = gof.PatternSub((T.add, local_erf_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), dict(pattern='y', constraint=_is_minus1),
...@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc) ...@@ -5124,7 +5111,7 @@ register_canonicalize(local_one_add_neg_erfc)
register_stabilize(local_one_add_neg_erfc) register_stabilize(local_one_add_neg_erfc)
register_specialize(local_one_add_neg_erfc) register_specialize(local_one_add_neg_erfc)
#(-1)+erfc(-x)=>erf(x) # (-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = gof.PatternSub((T.add, local_erf_neg_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), dict(pattern='y', constraint=_is_minus1),
(T.erfc, (T.neg, 'x'))), (T.erfc, (T.neg, 'x'))),
...@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one) ...@@ -5137,7 +5124,7 @@ register_canonicalize(local_erf_neg_minus_one)
register_stabilize(local_erf_neg_minus_one) register_stabilize(local_erf_neg_minus_one)
register_specialize(local_erf_neg_minus_one) register_specialize(local_erf_neg_minus_one)
#(-1)+erfc(-1*x)=>erf(x) # (-1)+erfc(-1*x)=>erf(x)
local_erf_neg_minus_one2 = gof.PatternSub((T.add, local_erf_neg_minus_one2 = gof.PatternSub((T.add,
dict(pattern='y', constraint=_is_minus1), dict(pattern='y', constraint=_is_minus1),
(T.erfc, (T.mul, -1, 'x'))), (T.erfc, (T.mul, -1, 'x'))),
...@@ -5176,8 +5163,8 @@ def local_log_erfc(node): ...@@ -5176,8 +5163,8 @@ def local_log_erfc(node):
x = node.inputs[0].owner.inputs[0] x = node.inputs[0].owner.inputs[0]
stab_value = (-x ** 2 - T.log(x) - .5 * T.log(numpy.pi) + stab_value = (-x ** 2 - T.log(x) - .5 * T.log(numpy.pi) +
T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) T.log(1 - 1 / (2 * x ** 2) + 3 / (4 * x ** 4) -
- 15 / (8 * x ** 6))) 15 / (8 * x ** 6)))
if (node.outputs[0].dtype == 'float32' or if (node.outputs[0].dtype == 'float32' or
node.outputs[0].dtype == 'float16'): node.outputs[0].dtype == 'float16'):
...@@ -5191,8 +5178,8 @@ def local_log_erfc(node): ...@@ -5191,8 +5178,8 @@ def local_log_erfc(node):
# Stability optimization of the grad of log(erfc(x)) # Stability optimization of the grad of log(erfc(x))
#([y*]exp(-(x**2)))/erfc(x) # The y* is optional # ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
#([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold, # ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) # sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explaination # for float64: threshold=26.63 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination # for float32: threshold=9.3 see at the end of the fct for the explaination
...@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node): ...@@ -5226,8 +5213,8 @@ def local_grad_log_erfc_neg(node):
if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2: if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2:
return False return False
y = mul.owner.inputs[0] y = mul.owner.inputs[0]
if (not mul.owner.inputs[1].owner if (not mul.owner.inputs[1].owner or
or mul.owner.inputs[1].owner.op != T.exp): mul.owner.inputs[1].owner.op != T.exp):
return False return False
exp = mul.owner.inputs[1] exp = mul.owner.inputs[1]
...@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node): ...@@ -5236,8 +5223,8 @@ def local_grad_log_erfc_neg(node):
if exp.owner.inputs[0].owner.op == T.neg: if exp.owner.inputs[0].owner.op == T.neg:
neg = exp.owner.inputs[0] neg = exp.owner.inputs[0]
if (not neg.owner.inputs[0].owner if (not neg.owner.inputs[0].owner or
or neg.owner.inputs[0].owner.op != T.sqr): neg.owner.inputs[0].owner.op != T.sqr):
return False return False
sqr = neg.owner.inputs[0] sqr = neg.owner.inputs[0]
x = sqr.owner.inputs[0] x = sqr.owner.inputs[0]
...@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node): ...@@ -5279,8 +5266,8 @@ def local_grad_log_erfc_neg(node):
return False return False
if len(mul_neg.owner.inputs) == 2: if len(mul_neg.owner.inputs) == 2:
if (not mul_neg.owner.inputs[1].owner if (not mul_neg.owner.inputs[1].owner or
or mul_neg.owner.inputs[1].owner.op != T.sqr): mul_neg.owner.inputs[1].owner.op != T.sqr):
return False return False
sqr = mul_neg.owner.inputs[1] sqr = mul_neg.owner.inputs[1]
x = sqr.owner.inputs[0] x = sqr.owner.inputs[0]
...@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node): ...@@ -5292,8 +5279,8 @@ def local_grad_log_erfc_neg(node):
return False return False
if cst2 != -1: if cst2 != -1:
if (not erfc_x.owner or erfc_x.owner.op != T.mul if (not erfc_x.owner or erfc_x.owner.op != T.mul or
or len(erfc_x.owner.inputs) != 2): len(erfc_x.owner.inputs) != 2):
# todo implement that case # todo implement that case
return False return False
if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]: if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
...@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node): ...@@ -5324,12 +5311,12 @@ def local_grad_log_erfc_neg(node):
# aaron value # aaron value
stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) + stab_value = (x * T.pow(1 - 1 / (2 * (x ** 2)) +
3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) 3 / (4 * (x ** 4)) - 15 / (8 * (x ** 6)), -1) *
* T.cast(T.sqrt(numpy.pi), dtype=x.dtype)) T.cast(T.sqrt(numpy.pi), dtype=x.dtype))
if x.dtype == 'float32' or x.dtype == 'float16': if x.dtype == 'float32' or x.dtype == 'float16':
threshold = 9.3 threshold = 9.3
#threshold = 10.1 # threshold = 10.1
elif x.dtype == 'float64': elif x.dtype == 'float64':
threshold = 26.641747557 threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
...@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5531,6 +5518,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
if maker is None: if maker is None:
def maker(node, scalar_op): def maker(node, scalar_op):
return OP(scalar_op) return OP(scalar_op)
def local_fuse(node): def local_fuse(node):
""" """
As part of specialization, we fuse two consecutive elemwise Ops of the As part of specialization, we fuse two consecutive elemwise Ops of the
...@@ -5598,13 +5586,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5598,13 +5586,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
# If a variable is used as multiple into to the same node, # If a variable is used as multiple into to the same node,
# we still want to fusion. So we take the set. # we still want to fusion. So we take the set.
if (i.owner and if (i.owner and
isinstance(i.owner.op, OP) and isinstance(i.owner.op, OP) and
len(set([n for n, idx in i.clients])) == 1 and len(set([n for n, idx in i.clients])) == 1 and
# Do not merge elemwise that don't have the same # Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate # broadcastable pattern to don't redo duplicate
# computation due to broadcast. # computation due to broadcast.
i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable): i.owner.outputs[0].broadcastable ==
node.outputs[0].broadcastable):
do_fusion = True do_fusion = True
try: try:
tmp_s_input = [] tmp_s_input = []
...@@ -5840,14 +5828,14 @@ def local_add_mul_fusion(node): ...@@ -5840,14 +5828,14 @@ def local_add_mul_fusion(node):
""" """
if (not isinstance(node.op, Elemwise) or if (not isinstance(node.op, Elemwise) or
not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))): not isinstance(node.op.scalar_op, (scalar.Add, scalar.Mul))):
return False return False
s_op = node.op.scalar_op.__class__ s_op = node.op.scalar_op.__class__
for inp in node.inputs: for inp in node.inputs:
if (inp.owner and if (inp.owner and
isinstance(inp.owner.op, Elemwise) and isinstance(inp.owner.op, Elemwise) and
isinstance(inp.owner.op.scalar_op, s_op)): isinstance(inp.owner.op.scalar_op, s_op)):
l = list(node.inputs) l = list(node.inputs)
l.remove(inp) l.remove(inp)
return [node.op(*(l + inp.owner.inputs))] return [node.op(*(l + inp.owner.inputs))]
...@@ -5882,13 +5870,15 @@ else: ...@@ -5882,13 +5870,15 @@ else:
# just returns the input, it should be removed from the graph to # just returns the input, it should be removed from the graph to
# make sure all possible optimizations can be applied. # make sure all possible optimizations can be applied.
register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_), register_canonicalize(gof.OpRemove(theano.gradient.consider_constant_),
'fast_compile', 'fast_run', name='remove_consider_constant') 'fast_compile', 'fast_run',
name='remove_consider_constant')
register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_), register_canonicalize(gof.OpRemove(theano.gradient.zero_grad_),
'fast_compile', 'fast_run', name='remove_zero_grad') 'fast_compile', 'fast_run', name='remove_zero_grad')
register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_), register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
'fast_compile', 'fast_run', name='remove_disconnected_grad') 'fast_compile', 'fast_run',
name='remove_disconnected_grad')
@register_canonicalize @register_canonicalize
......
...@@ -63,7 +63,6 @@ whitelist_flake8 = [ ...@@ -63,7 +63,6 @@ whitelist_flake8 = [
"tensor/sort.py", "tensor/sort.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/opt_uncanonicalize.py", "tensor/opt_uncanonicalize.py",
"tensor/opt.py",
"tensor/blas.py", "tensor/blas.py",
"tensor/extra_ops.py", "tensor/extra_ops.py",
"tensor/nlinalg.py", "tensor/nlinalg.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论