Commit 4cf7afb4 authored by Frédéric Bastien

Merge pull request #2952 from abergeron/flake8

Flake8 work
...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op): ...@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO: TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs) - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs - opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op): ...@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
 # not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
if isinstance(var, SharedVariable)] if isinstance(var, SharedVariable)]
used_inputs = [var for var in gof.graph.inputs(outputs)
if not isinstance(var, gof.Constant)]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars, new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, replace=dict(zip(self.shared_inputs,
...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op): ...@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" raise TypeError("Wrong type, expected %s but got %s" %
% (type, input.type)) (type, input.type))
return gof.Apply(self, return gof.Apply(self,
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op): ...@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
grad_ops = self.grad_ops grad_ops = self.grad_ops
else: else:
gs = theano.gradient.grad(cost=None, gs = theano.gradient.grad(cost=None,
known_grads=dict(zip(self.new_outputs, output_grads)), known_grads=dict(zip(self.new_outputs,
output_grads)),
wrt=self.new_inputs, wrt=self.new_inputs,
disconnected_inputs='ignore') disconnected_inputs='ignore')
......
...@@ -5,10 +5,12 @@ ...@@ -5,10 +5,12 @@
""" """
from __future__ import print_function from __future__ import print_function
__docformat__ = "restructuredtext en" import copy
import sys
import copy, sys, copy_reg, gc import copy_reg
import gc
from itertools import izip from itertools import izip
import logging
import numpy import numpy
...@@ -16,10 +18,9 @@ import theano ...@@ -16,10 +18,9 @@ import theano
from theano import gof from theano import gof
from theano.compat import get_unbound_function, product as itertools_product from theano.compat import get_unbound_function, product as itertools_product
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.gof import (FunctionGraph, graph, utils, link, from theano.gof import (graph, utils, link,
ops_with_inner_function) ops_with_inner_function)
from theano.gof.link import raise_with_op from theano.gof.link import raise_with_op
from theano.gof.cc import CLinker
from theano.configparser import (config, AddConfigVar, BoolParam, IntParam, from theano.configparser import (config, AddConfigVar, BoolParam, IntParam,
StrParam) StrParam)
from theano.compile.function_module import ( from theano.compile.function_module import (
...@@ -29,6 +30,8 @@ from theano.compile.function_module import ( ...@@ -29,6 +30,8 @@ from theano.compile.function_module import (
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard from theano.compile.ops import OutputGuard
__docformat__ = "restructuredtext en"
AddConfigVar('DebugMode.patience', AddConfigVar('DebugMode.patience',
"Optimize graph this many times to detect inconsistency", "Optimize graph this many times to detect inconsistency",
IntParam(10, lambda i: i > 0), IntParam(10, lambda i: i > 0),
...@@ -57,8 +60,8 @@ AddConfigVar('DebugMode.check_strides', ...@@ -57,8 +60,8 @@ AddConfigVar('DebugMode.check_strides',
AddConfigVar('DebugMode.warn_input_not_reused', AddConfigVar('DebugMode.warn_input_not_reused',
("Generate a warning when destroy_map or view_map says that an " ("Generate a warning when destroy_map or view_map says that an "
"op works inplace, but the op did not reuse the input for its output." "op works inplace, but the op did not reuse the input for its "
), "output."),
BoolParam(True), BoolParam(True),
in_c_key=False) in_c_key=False)
...@@ -94,7 +97,6 @@ AddConfigVar('DebugMode.check_preallocated_output_ndim', ...@@ -94,7 +97,6 @@ AddConfigVar('DebugMode.check_preallocated_output_ndim',
IntParam(4, lambda i: i > 0), IntParam(4, lambda i: i > 0),
in_c_key=False) in_c_key=False)
import logging
_logger = logging.getLogger("theano.compile.debugmode") _logger = logging.getLogger("theano.compile.debugmode")
...@@ -148,7 +150,7 @@ class BadThunkOutput(DebugModeError): ...@@ -148,7 +150,7 @@ class BadThunkOutput(DebugModeError):
def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()): def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()):
"""Initialize members""" """Initialize members"""
DebugModeError.__init__(self) # to be compatible with python2.4 super(BadThunkOutput, self).__init__()
self.r = r self.r = r
self.thunk1 = thunk1 self.thunk1 = thunk1
self.val1 = val1 self.val1 = val1
...@@ -173,8 +175,10 @@ class BadThunkOutput(DebugModeError): ...@@ -173,8 +175,10 @@ class BadThunkOutput(DebugModeError):
print(" op :", self.offending_op(), file=sio) print(" op :", self.offending_op(), file=sio)
print(" Outputs Type:", self.r.type, file=sio) print(" Outputs Type:", self.r.type, file=sio)
print(" Outputs Shape:", getattr(self.val1, 'shape', None), file=sio) print(" Outputs Shape:", getattr(self.val1, 'shape', None), file=sio)
print(" Outputs Strides:", getattr(self.val1, 'strides', None), file=sio) print(" Outputs Strides:", getattr(self.val1, 'strides', None),
print(" Inputs Type :", [i.type for i in self.r.owner.inputs], file=sio) file=sio)
print(" Inputs Type :", [i.type for i in self.r.owner.inputs],
file=sio)
print(" Inputs Shape:", [getattr(val, 'shape', None) print(" Inputs Shape:", [getattr(val, 'shape', None)
for val in self.inputs_val], file=sio) for val in self.inputs_val], file=sio)
print(" Inputs Strides:", [getattr(val, 'strides', None) print(" Inputs Strides:", [getattr(val, 'strides', None)
...@@ -226,7 +230,7 @@ class BadOptimization(DebugModeError): ...@@ -226,7 +230,7 @@ class BadOptimization(DebugModeError):
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason, def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph): old_graph, new_graph):
"""Initialize members""" """Initialize members"""
DebugModeError.__init__(self) # to be compatible with python2.4 super(BadOptimization, self).__init__()
self.old_r = old_r self.old_r = old_r
self.new_r = new_r self.new_r = new_r
self.old_r_val = old_r_val self.old_r_val = old_r_val
...@@ -287,24 +291,20 @@ class BadOptimization(DebugModeError): ...@@ -287,24 +291,20 @@ class BadOptimization(DebugModeError):
ov = numpy.asarray(self.old_r_val) ov = numpy.asarray(self.old_r_val)
nv = numpy.asarray(self.new_r_val) nv = numpy.asarray(self.new_r_val)
ssio = StringIO() ssio = StringIO()
print(" Max Abs Diff: ", numpy.max(numpy.absolute(nv - abs_diff = numpy.absolute(nv - ov)
ov)), file=ssio) print(" Max Abs Diff: ", numpy.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", numpy.mean(numpy.absolute(nv - print(" Mean Abs Diff: ", numpy.mean(abs_diff), file=ssio)
ov)), file=ssio) print(" Median Abs Diff: ", numpy.median(abs_diff), file=ssio)
print(" Median Abs Diff: ", numpy.median(numpy.absolute( print(" Std Abs Diff: ", numpy.std(abs_diff), file=ssio)
nv - ov)), file=ssio) arg_max_val = numpy.argmax(abs_diff)
print(" Std Abs Diff: ", numpy.std(numpy.absolute(
nv - ov)), file=ssio)
arg_max_val = numpy.argmax(numpy.absolute(nv - ov))
values_at_max = (nv.flatten()[arg_max_val], values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val]) ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio) print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when # N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0 # nv == ov == 0
reldiff = (numpy.absolute(nv - ov) reldiff = (abs_diff /
 / numpy.maximum( numpy.maximum(numpy.absolute(nv) + numpy.absolute(ov),
numpy.absolute(nv) + numpy.absolute(ov),
1e-8)) 1e-8))
print(" Max Rel Diff: ", numpy.max(reldiff), file=ssio) print(" Max Rel Diff: ", numpy.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", numpy.mean(reldiff), file=ssio) print(" Mean Rel Diff: ", numpy.mean(reldiff), file=ssio)
...@@ -325,8 +325,10 @@ class BadOptimization(DebugModeError): ...@@ -325,8 +325,10 @@ class BadOptimization(DebugModeError):
print(" New Graph:", file=sio) print(" New Graph:", file=sio)
print(self.new_graph, file=sio) print(self.new_graph, file=sio)
print("", file=sio) print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1", file=sio) print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
print(" or even tensor.cmp_sloppy=2 for less-strict comparison", file=sio) file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue() return sio.getvalue()
...@@ -334,8 +336,7 @@ class BadDestroyMap(DebugModeError): ...@@ -334,8 +336,7 @@ class BadDestroyMap(DebugModeError):
"""Exception: Some perform() or c_code() modified an input that """Exception: Some perform() or c_code() modified an input that
wasn't in the destroy_map""" wasn't in the destroy_map"""
def __init__(self, node, idx, old_val, new_val, perform): def __init__(self, node, idx, old_val, new_val, perform):
#super(BadDestroyMap, self).__init__() super(BadDestroyMap, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.node = node self.node = node
self.idx = idx self.idx = idx
self.old_val = old_val self.old_val = old_val
...@@ -351,26 +352,38 @@ class BadDestroyMap(DebugModeError): ...@@ -351,26 +352,38 @@ class BadDestroyMap(DebugModeError):
print(" destroy_map:", getattr(self.node.op, print(" destroy_map:", getattr(self.node.op,
'destroy_map', {}), file=sio) 'destroy_map', {}), file=sio)
print(" changed input idx:", self.idx, file=sio) print(" changed input idx:", self.idx, file=sio)
print(" changed input type:", self.node.inputs[self.idx].type, file=sio) print(" changed input type:", self.node.inputs[self.idx].type,
file=sio)
print(" repr (old val):", repr(self.old_val), file=sio) print(" repr (old val):", repr(self.old_val), file=sio)
print(" repr (new val):", repr(self.new_val), file=sio) print(" repr (new val):", repr(self.new_val), file=sio)
try: try:
npy_old_val = numpy.asarray(self.old_val) npy_old_val = numpy.asarray(self.old_val)
npy_new_val = numpy.asarray(self.new_val) npy_new_val = numpy.asarray(self.new_val)
print(" value dtype (new <space> old):", npy_new_val.dtype, npy_old_val.dtype, file=sio) print(" value dtype (new <space> old):", npy_new_val.dtype,
print(" value shape (new <space> old):", npy_new_val.shape, npy_old_val.shape, file=sio) npy_old_val.dtype, file=sio)
print(" value min (new <space> old):", npy_new_val.min(), npy_old_val.min(), file=sio) print(" value shape (new <space> old):", npy_new_val.shape,
print(" value max (new <space> old):", npy_new_val.max(), npy_old_val.max(), file=sio) npy_old_val.shape, file=sio)
print(" value min (new <space> old):", npy_new_val.min(),
npy_old_val.min(), file=sio)
print(" value max (new <space> old):", npy_new_val.max(),
npy_old_val.max(), file=sio)
delta = npy_new_val - npy_old_val delta = npy_new_val - npy_old_val
print(" value min (new-old):", delta.min(), file=sio) print(" value min (new-old):", delta.min(), file=sio)
print(" value max (new-old):", delta.max(), file=sio) print(" value max (new-old):", delta.max(), file=sio)
print(" value argmin (new-old):", numpy.unravel_index(delta.argmin(), npy_new_val.shape), file=sio) print(" value argmin (new-old):",
print(" value argmax (new-old):", numpy.unravel_index(delta.argmax(), npy_new_val.shape), file=sio) numpy.unravel_index(delta.argmin(), npy_new_val.shape),
print(" location of first 10 mismatches:", numpy.transpose(numpy.nonzero(delta))[:10], file=sio) file=sio)
print(" value argmax (new-old):",
numpy.unravel_index(delta.argmax(), npy_new_val.shape),
file=sio)
print(" location of first 10 mismatches:",
numpy.transpose(numpy.nonzero(delta))[:10], file=sio)
print("", file=sio) print("", file=sio)
except Exception as e: except Exception as e:
print("(Numpy-hints failed with: %s)" % str(e), file=sio) print("(Numpy-hints failed with: %s)" % str(e), file=sio)
print(" Hint: this can also be caused by a deficient values_eq_approx() or __eq__() implementation [which compared input values]", file=sio) print(" Hint: this can also be caused by a deficient "
"values_eq_approx() or __eq__() implementation "
"[which compared input values]", file=sio)
return sio.getvalue() return sio.getvalue()
...@@ -379,8 +392,7 @@ class BadViewMap(DebugModeError): ...@@ -379,8 +392,7 @@ class BadViewMap(DebugModeError):
that wasn't in the view_map""" that wasn't in the view_map"""
def __init__(self, node, output_idx, out_storage, def __init__(self, node, output_idx, out_storage,
in_alias_idx=None, out_alias_idx=None): in_alias_idx=None, out_alias_idx=None):
#super(BadViewMap, self).__init__() super(BadViewMap, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.node = node self.node = node
self.output_idx = output_idx self.output_idx = output_idx
self.out_storage = out_storage self.out_storage = out_storage
...@@ -425,8 +437,7 @@ class InvalidValueError(DebugModeError): ...@@ -425,8 +437,7 @@ class InvalidValueError(DebugModeError):
the Type of that output""" the Type of that output"""
def __init__(self, r, v, client_node=None, hint='none', def __init__(self, r, v, client_node=None, hint='none',
specific_hint='none'): specific_hint='none'):
#super(InvalidValueError, self).__init__() super(InvalidValueError, self).__init__()
DebugModeError.__init__(self) # to be compatible with python2.4
self.r = r self.r = r
self.v = v self.v = v
self.client_node = client_node self.client_node = client_node
...@@ -454,7 +465,8 @@ class InvalidValueError(DebugModeError): ...@@ -454,7 +465,8 @@ class InvalidValueError(DebugModeError):
client_node = self.client_node client_node = self.client_node
hint = self.hint hint = self.hint
specific_hint = self.specific_hint specific_hint = self.specific_hint
context = debugprint(r, prefix=' ', depth=12, file=StringIO()).getvalue() context = debugprint(r, prefix=' ', depth=12,
file=StringIO()).getvalue()
return """InvalidValueError return """InvalidValueError
type(variable) = %(type_r)s type(variable) = %(type_r)s
variable = %(r)s variable = %(r)s
...@@ -512,7 +524,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -512,7 +524,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
and their associated printed ids and their associated printed ids
:param print_type: whether to print the Variable type after the other infos :param print_type: whether to print the Variable type after the other infos
:param file: file-like object to which to print :param file: file-like object to which to print
:param print_destroy_map: whether to print the op destroy_map after other info :param print_destroy_map: whether to print the op destroy_map after
other info
:param print_view_map: whether to print the op view_map after other info :param print_view_map: whether to print the op view_map after other info
:param order: If not empty will print the index in the toposort. :param order: If not empty will print the index in the toposort.
:param ids: How do we print the identifier of the variable :param ids: How do we print the identifier of the variable
...@@ -592,7 +605,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -592,7 +605,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
already_printed = a in done # get_id_str put it in the dict already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a) id_str = get_id_str(a)
if profile == None or a not in profile.apply_time: if profile is None or a not in profile.apply_time:
if len(a.outputs) == 1: if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
id_str, id_str,
...@@ -617,7 +630,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -617,7 +630,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1: if len(a.outputs) == 1:
print('%s%s %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ print("%s%s %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op, % (prefix, a.op,
id_str, id_str,
type_str, type_str,
...@@ -629,7 +643,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -629,7 +643,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
tot_time, tot_time,
tot_time_percent), file=file) tot_time_percent), file=file)
else: else:
print('%s%s.%i %s%s \'%s\' %s %s %s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ print("%s%s.%i %s%s '%s' %s %s %s --> "
"%8.2es %4.1f%% %8.2es %4.1f%%"
% (prefix, a.op, % (prefix, a.op,
a.outputs.index(r), a.outputs.index(r),
id_str, type_str, id_str, type_str,
...@@ -652,14 +667,15 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -652,14 +667,15 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
new_prefix_child = prefix_child + ' ' new_prefix_child = prefix_child + ' '
if hasattr(i, 'owner') and hasattr(i.owner, 'op'): if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan): if isinstance(i.owner.op,
theano.scan_module.scan_op.Scan):
scan_ops.append(i) scan_ops.append(i)
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops, prefix_child=new_prefix_child,
profile=profile) scan_ops=scan_ops, profile=profile)
else: else:
# this is an input variable # this is an input variable
...@@ -679,7 +695,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -679,7 +695,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
:param accept_inplace: are inplace ops permitted in the original graph? :param accept_inplace: are inplace ops permitted in the original graph?
:type accept_inplace: Bool :type accept_inplace: Bool
:rtype: `FunctionGraph` :rtype: `FunctionGraph`
:returns: a new FunctionGraph with a cloned graph, with debugging `Feature` instances already installed. :returns: a new FunctionGraph with a cloned graph, with debugging
`Feature` instances already installed.
""" """
orig_inputs = [spec.variable for spec in input_specs] orig_inputs = [spec.variable for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
...@@ -687,15 +704,15 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -687,15 +704,15 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs, fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
features=[equivalence_tracker])
# DestroyHandler may not be needed yet, as there is usually no # DestroyHandler may not be needed yet, as there is usually no
# inplace operation in the graph at this stage. DestroyHandler # inplace operation in the graph at this stage. DestroyHandler
# will be installed by an optimization after canonicalization, # will be installed by an optimization after canonicalization,
# before the inplace operations are applied. # before the inplace operations are applied. This results in a big
# This results in a big speed gain. # speed gain.
#
# If inplace operations are accepted and present, however, # If inplace operations are accepted and present, however,
# DestroyHandler will be inserted in the loop below. # DestroyHandler will be inserted in the loop below.
# features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)])
features=[equivalence_tracker])
if not accept_inplace: if not accept_inplace:
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
...@@ -711,9 +728,10 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -711,9 +728,10 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
break break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, fgraph.inputs) fgraph.attach_feature(Supervisor(
if not (spec.mutable or (hasattr(fgraph, 'destroyers') input for spec, input in zip(input_specs, fgraph.inputs)
and fgraph.destroyers(input))))) if not (spec.mutable or (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input)))))
for feature in std_fgraph.features: for feature in std_fgraph.features:
fgraph.attach_feature(feature()) fgraph.attach_feature(feature())
...@@ -797,8 +815,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -797,8 +815,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
continue continue
if not may_share: if not may_share:
_logger.warning("Optimization Warning: input idx %d marked " _logger.warning("Optimization Warning: input idx %d marked "
"as viewed but new memory allocated by node '%s'", "as viewed but new memory allocated by node "
ii[0], str(node)) "'%s'", ii[0], str(node))
for r_idx, r in enumerate(node.inputs): for r_idx, r in enumerate(node.inputs):
if not r.type.values_eq(r_vals[r], storage_map[r][0]): if not r.type.values_eq(r_vals[r], storage_map[r][0]):
...@@ -808,10 +826,12 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -808,10 +826,12 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
# ok, we expected r to be destroyed # ok, we expected r to be destroyed
if node in active_nodes: if node in active_nodes:
if dr_vals.get(r, (0, node))[1] is not node: if dr_vals.get(r, (0, node))[1] is not node:
# bad: there should only be one active node that destroys any variable # bad: there should only be one active node
# that destroys any variable
raise Exception('failure in topological ordering') raise Exception('failure in topological ordering')
if clobber_dr_vals: if clobber_dr_vals:
dr_vals[r] = (storage_map[r][0], node) # no copy, this is the last use of this variable # no copy, this is the last use of this variable
dr_vals[r] = (storage_map[r][0], node)
 # make sure that dr_vals[r] doesn't get used again # make sure that dr_vals[r] doesn't get used again
storage_map[r][0] = data_destroyed storage_map[r][0] = data_destroyed
else: else:
...@@ -823,8 +843,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -823,8 +843,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
def _check_viewmap(node, storage_map): def _check_viewmap(node, storage_map):
""" """
This functions raises a BadViewMap exception when it detects the following: This functions raises a BadViewMap exception when it detects the
- output node storages aliased to input storage, with no declaration in view_map following:
- output node storages aliased to input storage, with no declaration
in view_map
- if not aliased to an input, check if two outputs are aliased together - if not aliased to an input, check if two outputs are aliased together
and used subsequently in the graph and used subsequently in the graph
""" """
...@@ -1062,16 +1084,13 @@ def _find_bad_optimizations2(order, reasons, r_vals): ...@@ -1062,16 +1084,13 @@ def _find_bad_optimizations2(order, reasons, r_vals):
return return
checked_variables.add(r) checked_variables.add(r)
# (recursively) first check all the variables that could make r look bad: # (recursively) first check all the variables that could make
# r look bad:
list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]] list_of_vars = [old_r for (reason, old_r, olds, news) in reasons[r]]
if (None is not r.owner): if (None is not r.owner):
list_of_vars += r.owner.inputs list_of_vars += r.owner.inputs
for var_that_could_make_r_look_bad in \ for var_that_could_make_r_look_bad in list_of_vars:
list_of_vars:
# backport
#[old_r for (reason, old_r, olds, news) in reasons[r]] \
#+ ([] if (None is r.owner) else r.owner.inputs):
check_variable(var_that_could_make_r_look_bad) check_variable(var_that_could_make_r_look_bad)
check_variable_norec(r) check_variable_norec(r)
...@@ -1087,8 +1106,8 @@ _find_bad_optimizations = _find_bad_optimizations0 ...@@ -1087,8 +1106,8 @@ _find_bad_optimizations = _find_bad_optimizations0
def _get_preallocated_maps(node, thunk, prealloc_modes, def_val, def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
storage_map, r_vals, dr_vals, perform, active_order_set, storage_map, r_vals, dr_vals, perform,
inplace_outs, init_outputs): active_order_set, inplace_outs, init_outputs):
'''Preallocate outputs in different memory layouts''' '''Preallocate outputs in different memory layouts'''
# To avoid circular imports # To avoid circular imports
...@@ -1310,8 +1329,8 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val, ...@@ -1310,8 +1329,8 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
def _check_preallocated_output(node, thunk, prealloc_modes, def_val, def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
storage_map, r_vals, dr_vals, perform, active_order_set, storage_map, r_vals, dr_vals, perform,
inplace_outs, init_outputs): active_order_set, inplace_outs, init_outputs):
'''Try to apply thunk() on different output storages''' '''Try to apply thunk() on different output storages'''
# If node has an inner compiled Theano function with mode DebugMode, # If node has an inner compiled Theano function with mode DebugMode,
...@@ -1321,9 +1340,9 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val, ...@@ -1321,9 +1340,9 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
if type(getattr(node, 'op', None)) in ops_with_inner_function: if type(getattr(node, 'op', None)) in ops_with_inner_function:
fn_attr_name = ops_with_inner_function[type(node.op)] fn_attr_name = ops_with_inner_function[type(node.op)]
fn = getattr(node.op, fn_attr_name, None) fn = getattr(node.op, fn_attr_name, None)
if (not fn if (not fn or
or not hasattr(fn, 'maker') not hasattr(fn, 'maker') or
or not hasattr(fn.maker, 'mode')): not hasattr(fn.maker, 'mode')):
_logger.warn('Expected theano function not found in %s.%s', _logger.warn('Expected theano function not found in %s.%s',
node.op, fn_attr_name) node.op, fn_attr_name)
else: else:
...@@ -1378,7 +1397,8 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val, ...@@ -1378,7 +1397,8 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
# Check outputs # Check outputs
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(
r, storage_map[r][0],
hint=thunk_name, hint=thunk_name,
specific_hint=r.type.value_validity_msg( specific_hint=r.type.value_validity_msg(
storage_map[r][0])) storage_map[r][0]))
...@@ -1393,10 +1413,13 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val, ...@@ -1393,10 +1413,13 @@ def _check_preallocated_output(node, thunk, prealloc_modes, def_val,
for r in node.outputs: for r in node.outputs:
if not check_eq(r, r_vals[r], storage_map[r][0]): if not check_eq(r, r_vals[r], storage_map[r][0]):
# TODO: indicate it is not a C/Py problem # TODO: indicate it is not a C/Py problem
inputs_val = [storage_map[inp][0] for inp in r.owner.inputs] inputs_val = [storage_map[inp][0] for inp in
r.owner.inputs]
raise BadThunkOutput(r, raise BadThunkOutput(r,
thunk1='Reference value', val1=r_vals[r], thunk1='Reference value',
thunk2=thunk_name, val2=storage_map[r][0], val1=r_vals[r],
thunk2=thunk_name,
val2=storage_map[r][0],
inputs_val=inputs_val) inputs_val=inputs_val)
# Clear storage_map # Clear storage_map
...@@ -1455,8 +1478,6 @@ class _FunctionGraphEvent(object): ...@@ -1455,8 +1478,6 @@ class _FunctionGraphEvent(object):
str(self.op), str(self.op),
str(self.idx), str(self.idx),
msg]) msg])
# backport
# str(len(self.node.inputs)) if (self.op != 'output') else ''])
else: else:
return str(self.__dict__) return str(self.__dict__)
...@@ -1475,7 +1496,8 @@ class _FunctionGraphEvent(object): ...@@ -1475,7 +1496,8 @@ class _FunctionGraphEvent(object):
class _VariableEquivalenceTracker(object): class _VariableEquivalenceTracker(object):
"""A FunctionGraph Feature that keeps tabs on an FunctionGraph and tries to detect problems.""" """A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems."""
fgraph = None fgraph = None
"""WRITEME""" """WRITEME"""
...@@ -1524,7 +1546,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1524,7 +1546,6 @@ class _VariableEquivalenceTracker(object):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent('prune', node, self.event_list.append(_FunctionGraphEvent('prune', node,
reason=reason)) reason=reason))
# print 'PRUNING NODE', node, id(node)
assert node in self.active_nodes assert node in self.active_nodes
assert node not in self.inactive_nodes assert node not in self.inactive_nodes
self.active_nodes.remove(node) self.active_nodes.remove(node)
...@@ -1534,7 +1555,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1534,7 +1555,6 @@ class _VariableEquivalenceTracker(object):
self.event_list.append(_FunctionGraphEvent('import', node, self.event_list.append(_FunctionGraphEvent('import', node,
reason=reason)) reason=reason))
# print 'NEW NODE', node, id(node)
assert node not in self.active_nodes assert node not in self.active_nodes
self.active_nodes.add(node) self.active_nodes.add(node)
...@@ -1554,7 +1574,6 @@ class _VariableEquivalenceTracker(object): ...@@ -1554,7 +1574,6 @@ class _VariableEquivalenceTracker(object):
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
# print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.event_list.append(_FunctionGraphEvent('change', node, self.event_list.append(_FunctionGraphEvent('change', node,
reason=str(reason), idx=i)) reason=str(reason), idx=i))
...@@ -1570,7 +1589,8 @@ class _VariableEquivalenceTracker(object): ...@@ -1570,7 +1589,8 @@ class _VariableEquivalenceTracker(object):
# N.B. compute the debugprint now, because future # N.B. compute the debugprint now, because future
# optimizations will change the graph # optimizations will change the graph
done = dict() done = dict()
self.reasons[new_r].append((reason, self.reasons[new_r].append(
(reason,
r, r,
debugprint(r, prefix=' ', depth=6, debugprint(r, prefix=' ', depth=6,
file=StringIO(), done=done).getvalue(), file=StringIO(), done=done).getvalue(),
...@@ -1647,26 +1667,28 @@ class _Linker(gof.link.LocalLinker): ...@@ -1647,26 +1667,28 @@ class _Linker(gof.link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler=None, input_storage=None def make_all(self, profiler=None, input_storage=None,
, output_storage=None): output_storage=None):
if 1:
# can't import at toplevel because of circular import TODO: # can't import at toplevel because of circular import TODO:
# don't do this ugly hacky way of setting the # don't do this ugly hacky way of setting the
# filter_checks_isfinite # filter_checks_isfinite
from theano.tensor import TensorType # to set filter_check_isfinite from theano.tensor import TensorType # to set filter_check_isfinite
fgraph = self.fgraph fgraph = self.fgraph
input_storage_ = input_storage input_storage_ = input_storage
output_storage_ = output_storage output_storage_ = output_storage
#order = self.schedule(fgraph)
# Compute a topological ordering that IGNORES the destroy_map of destructive Ops. # Compute a topological ordering that IGNORES the destroy_map
# This will be OK, because every thunk is evaluated on a copy of its input. # of destructive Ops. This will be OK, because every thunk is
order_outputs = copy.copy(fgraph.equivalence_tracker.all_variables_ever) # evaluated on a copy of its input.
fgraph_equiv = fgraph.equivalence_tracker
order_outputs = copy.copy(fgraph_equiv.all_variables_ever)
del fgraph_equiv
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(fgraph.inputs, order_outputs) order = graph.io_toposort(fgraph.inputs, order_outputs)
active_order = self.schedule(fgraph) # an ordering of just the active nodes # an ordering of just the active nodes
active_order = self.schedule(fgraph)
active_order_set = set(active_order) active_order_set = set(active_order)
# Disable no_recycling, in order to be able to use # Disable no_recycling, in order to be able to use
...@@ -1682,9 +1704,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1682,9 +1704,6 @@ class _Linker(gof.link.LocalLinker):
thunks_c = [] # c thunks thunks_c = [] # c thunks
for node in order: for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
compute_map = {} compute_map = {}
for k in node.inputs: for k in node.inputs:
compute_map[k] = [True] compute_map[k] = [True]
...@@ -1696,7 +1715,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1696,7 +1715,8 @@ class _Linker(gof.link.LocalLinker):
# the compilation of some dependency is triggered there. # the compilation of some dependency is triggered there.
thunk_other = None thunk_other = None
if get_unbound_function(node.op.make_thunk) not in default_make_thunk: if (get_unbound_function(node.op.make_thunk) not in
default_make_thunk):
thunk = node.op.make_thunk(node, thunk = node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
...@@ -1725,7 +1745,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1725,7 +1745,8 @@ class _Linker(gof.link.LocalLinker):
# raises an not implemented exception), so in those cases we # raises an not implemented exception), so in those cases we
# consider that we don't have a python implementation # consider that we don't have a python implementation
if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and
node.op.perform.func_code != gof.op.PureOp.perform.func_code): (node.op.perform.func_code !=
gof.op.PureOp.perform.func_code)):
thunk = node.op.make_py_thunk(node, storage_map, compute_map, thunk = node.op.make_py_thunk(node, storage_map, compute_map,
no_recycling) no_recycling)
thunks_py.append(thunk) thunks_py.append(thunk)
...@@ -1739,7 +1760,9 @@ class _Linker(gof.link.LocalLinker): ...@@ -1739,7 +1760,9 @@ class _Linker(gof.link.LocalLinker):
elif thunks_c[-1] is None: elif thunks_c[-1] is None:
thunks_c[-1] = thunk_other thunks_c[-1] = thunk_other
else: else:
_logger.warn("We won't check the perform function of node '%s' but we will check its make_thunk function" % node) _logger.warn("We won't check the perform function "
"of node '%s' but we will check its "
"make_thunk function" % node)
thunks_py[-1] = thunk_other thunks_py[-1] = thunk_other
# Use self.no_recycling (that was passed in accept()) to always # Use self.no_recycling (that was passed in accept()) to always
...@@ -1747,7 +1770,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1747,7 +1770,8 @@ class _Linker(gof.link.LocalLinker):
# function's outputs. no_recycling_map will be used in f() below. # function's outputs. no_recycling_map will be used in f() below.
if self.no_recycling is True: if self.no_recycling is True:
no_recycling_map = storage_map.values() no_recycling_map = storage_map.values()
no_recycling_map = utils.difference(no_recycling_map, input_storage) no_recycling_map = utils.difference(no_recycling_map,
input_storage)
else: else:
no_recycling_map = [storage_map[r] for r in self.no_recycling no_recycling_map = [storage_map[r] for r in self.no_recycling
if r not in fgraph.inputs] if r not in fgraph.inputs]
...@@ -1784,14 +1808,17 @@ class _Linker(gof.link.LocalLinker): ...@@ -1784,14 +1808,17 @@ class _Linker(gof.link.LocalLinker):
# the evaluation of this function, even when the graph # the evaluation of this function, even when the graph
# has destructive ops in it # has destructive ops in it
# #
# This dictionary is used to populate the storage_map as necessary # This dictionary is used to populate the storage_map
# as necessary
r_vals = {} r_vals = {}
# dr_vals are the values taken by variables after being destroyed # dr_vals are the values taken by variables after
# being destroyed
dr_vals = {} dr_vals = {}
assert len(thunks_py) == len(order) assert len(thunks_py) == len(order)
# transfer the initial values from the storage_map to the r_vals # transfer the initial values from the storage_map to
# the r_vals
_logger.debug("DEBUGMODE: transfer initial values") _logger.debug("DEBUGMODE: transfer initial values")
# r_vals_initialized keeps track of the values that have # r_vals_initialized keeps track of the values that have
# actually been transferred from storage_map to r_vals # actually been transferred from storage_map to r_vals
...@@ -1803,17 +1830,20 @@ class _Linker(gof.link.LocalLinker): ...@@ -1803,17 +1830,20 @@ class _Linker(gof.link.LocalLinker):
# for a Generic object). We only want to raise # for a Generic object). We only want to raise
# an error if it is not valid. # an error if it is not valid.
if (storage_map[r][0] is None): if (storage_map[r][0] is None):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(
hint="Graph Input '%s' is missing" % str(r)) r, storage_map[r][0],
raise InvalidValueError(r, storage_map[r][0], hint=("Graph Input '%s' is missing" %
str(r)))
raise InvalidValueError(
r, storage_map[r][0],
hint=("Graph Input '%s' has invalid value " hint=("Graph Input '%s' has invalid value "
"%s" % (r, storage_map[r][0]))) "%s" % (r, storage_map[r][0])))
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
storage_map[r][0] = None storage_map[r][0] = None
r_vals_initialized.append(r) r_vals_initialized.append(r)
# store preallocated outputs in another map, and test the thunks on # store preallocated outputs in another map, and test
# them as output storages. # the thunks on them as output storages.
init_outputs = {} init_outputs = {}
for r in storage_map: for r in storage_map:
if r in fgraph.outputs: if r in fgraph.outputs:
...@@ -1835,8 +1865,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1835,8 +1865,6 @@ class _Linker(gof.link.LocalLinker):
for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py, for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py,
thunks_c, thunks_c,
order)): order)):
this_node_destroyed_variables = set()
_logger.debug("%i - starting node %i %s", i, i, node) _logger.debug("%i - starting node %i %s", i, i, node)
# put a copy of each input into the storage_map # put a copy of each input into the storage_map
...@@ -1844,7 +1872,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1844,7 +1872,6 @@ class _Linker(gof.link.LocalLinker):
for r in node.inputs: for r in node.inputs:
assert isinstance(r, gof.Variable) assert isinstance(r, gof.Variable)
assert r in r_vals assert r in r_vals
# print >> sys.stderr,i, "DEBUGMODE: deepcopy input ", r
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r]) storage_map[r][0] = _lessbroken_deepcopy(r_vals[r])
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(r, storage_map[r][0],
...@@ -1872,10 +1899,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -1872,10 +1899,12 @@ class _Linker(gof.link.LocalLinker):
raise raise
opt = str(reason[0][0]) opt = str(reason[0][0])
msg = ( msg = (
"An optimization (probably %s ) inserted an apply node that raise an error." % opt + "An optimization (probably %s) inserted an "
"\nThe information we have about this optimizations is:" + str(reason[0][1]) + "apply node that raise an error." % opt +
"\n" + reason[0][2] + "\nThe information we have about this "
"\n\nThe original exception: \n" + str(e)) "optimizations is:" + str(reason[0][1]) +
"\n" + reason[0][2] +
"\n\nThe original exception: \n" + str(e))
new_e = e.__class__(msg) new_e = e.__class__(msg)
exc_type, exc_value, exc_trace = sys.exc_info() exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e exc_value = new_e
...@@ -1891,19 +1920,19 @@ class _Linker(gof.link.LocalLinker): ...@@ -1891,19 +1920,19 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(r, storage_map[r][0],
hint='perform output', hint='perform output',
specific_hint=hint2) specific_hint=hint2)
warn_inp = config.DebugMode.warn_input_not_reused
py_inplace_outs = _check_inputs( py_inplace_outs = _check_inputs(
node, storage_map, r_vals, dr_vals, node, storage_map, r_vals, dr_vals,
active_order_set, active_order_set,
clobber_dr_vals=True, perform='py', clobber_dr_vals=True, perform='py',
warn_input_not_reused=config.DebugMode.warn_input_not_reused) warn_input_not_reused=warn_inp)
_check_viewmap(node, storage_map) _check_viewmap(node, storage_map)
# Retrieve each output from the storage_map # Retrieve each output from the storage_map.
# The return values of this first run will be the reference ones # The return values of this first run will be
# the reference ones
for r in node.outputs: for r in node.outputs:
assert r not in r_vals assert r not in r_vals
# print >> sys.stderr, i, "DEBUGMODE storing reference output %x" % id(storage_map[r][0])
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
# clear the storage_map of outputs for the thunk_c # clear the storage_map of outputs for the thunk_c
storage_map[r][0] = None storage_map[r][0] = None
...@@ -1927,9 +1956,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -1927,9 +1956,6 @@ class _Linker(gof.link.LocalLinker):
inplace_outs=py_inplace_outs, inplace_outs=py_inplace_outs,
init_outputs=init_outputs) init_outputs=init_outputs)
# print >> sys.stderr, i, "DEBUGMODE thunk_py %100s %50s %30s" % (node,
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.inputs],
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.outputs])
sys.stdout.flush() sys.stdout.flush()
if thunk_c: if thunk_c:
...@@ -1939,18 +1965,22 @@ class _Linker(gof.link.LocalLinker): ...@@ -1939,18 +1965,22 @@ class _Linker(gof.link.LocalLinker):
dmap = getattr(node.op, 'destroy_map', {}) dmap = getattr(node.op, 'destroy_map', {})
vmap = getattr(node.op, 'view_map', {}) vmap = getattr(node.op, 'view_map', {})
for i, r in enumerate(node.inputs): for i, r in enumerate(node.inputs):
# if thunk_py ran, and we still got this far, # if thunk_py ran, and we still got
# it means that the destroy_map of the Op (and view_map) are # this far, it means that the
# accurate # destroy_map of the Op (and view_map)
# so we can assume that inputs not marked as destroyed have in # are accurate so we can assume that
# fact not been destroyed. # inputs not marked as destroyed have
# Therefore... we only need to overwrite inputs that *have* # in fact not been destroyed.
# been marked as destroyed. # Therefore... we only need to
# Inputs marked as viewd are unsafe too, # overwrite inputs that *have* been
# because the corresponding output can # marked as destroyed. Inputs marked
# be destroyed. # as viewd are unsafe too, because the
if any(i in v for v in (dmap.values() + vmap.values())): # corresponding output can be
storage_map[r][0] = _lessbroken_deepcopy(r_vals[r]) # destroyed.
if any(i in v for v in (dmap.values() +
vmap.values())):
storage_map[r][0] = _lessbroken_deepcopy(
r_vals[r])
clobber = False clobber = False
...@@ -1969,10 +1999,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -1969,10 +1999,12 @@ class _Linker(gof.link.LocalLinker):
raise raise
opt = str(reason[0][0]) opt = str(reason[0][0])
msg = ( msg = (
"An optimization (probably %s ) inserted an apply node that raise an error." % opt + "An optimization (probably %s) inserted "
"\nThe information we have about this optimizations is:" + str(reason[0][1]) + "an apply node that raise an error." % opt +
"\n" + reason[0][2] + "\nThe information we have about this "
"\n\nThe original exception: \n" + str(e)) "optimizations is:" + str(reason[0][1]) +
"\n" + reason[0][2] +
"\n\nThe original exception: \n" + str(e))
new_e = e.__class__(msg) new_e = e.__class__(msg)
exc_type, exc_value, exc_trace = sys.exc_info() exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e exc_value = new_e
...@@ -1982,21 +2014,26 @@ class _Linker(gof.link.LocalLinker): ...@@ -1982,21 +2014,26 @@ class _Linker(gof.link.LocalLinker):
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], hint='c output') raise InvalidValueError(r, storage_map[r][0],
hint='c output')
if thunk_py: if thunk_py:
assert r in r_vals # because we put it in during the thunk_py branch # because we put it in during the
# check for stride correctness (may raise exception) # thunk_py branch
_check_strides_match(r_vals[r], assert r in r_vals
storage_map[r][0], # check for stride correctness (may
# raise exception)
_check_strides_match(
r_vals[r], storage_map[r][0],
self.maker.mode.require_matching_strides, self.maker.mode.require_matching_strides,
node.op) node.op)
warn_inp = config.DebugMode.warn_input_not_reused
c_inplace_outs = _check_inputs( c_inplace_outs = _check_inputs(
node, storage_map, r_vals, node, storage_map, r_vals,
dr_vals, active_order_set, dr_vals, active_order_set,
clobber_dr_vals=clobber, perform='c', clobber_dr_vals=clobber, perform='c',
warn_input_not_reused=config.DebugMode.warn_input_not_reused) warn_input_not_reused=warn_inp)
_check_viewmap(node, storage_map) _check_viewmap(node, storage_map)
...@@ -2006,11 +2043,14 @@ class _Linker(gof.link.LocalLinker): ...@@ -2006,11 +2043,14 @@ class _Linker(gof.link.LocalLinker):
# compares the version from thunk_py # compares the version from thunk_py
# (in r_vals) to the version produced # (in r_vals) to the version produced
# by thunk_c (in storage_map) # by thunk_c (in storage_map)
if not check_eq(r, r_vals[r], storage_map[r][0]): if not check_eq(r, r_vals[r],
inputs_val = [storage_map[inp][0] for inp in r.owner.inputs] storage_map[r][0]):
inputs_val = [storage_map[inp][0]
for inp in r.owner.inputs]
raise BadThunkOutput( raise BadThunkOutput(
r, thunk1='perform', val1=r_vals[r], r, thunk1='perform', val1=r_vals[r],
thunk2='c_code', val2=storage_map[r][0], thunk2='c_code',
val2=storage_map[r][0],
inputs_val=inputs_val) inputs_val=inputs_val)
else: else:
# retrieve each output from the storage_map # retrieve each output from the storage_map
...@@ -2021,6 +2061,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -2021,6 +2061,7 @@ class _Linker(gof.link.LocalLinker):
if self.maker.mode.check_preallocated_output: if self.maker.mode.check_preallocated_output:
prealloc_modes = \ prealloc_modes = \
self.maker.mode.check_preallocated_output self.maker.mode.check_preallocated_output
def thunk(): def thunk():
try: try:
thunk_c() thunk_c()
...@@ -2042,15 +2083,11 @@ class _Linker(gof.link.LocalLinker): ...@@ -2042,15 +2083,11 @@ class _Linker(gof.link.LocalLinker):
inplace_outs=c_inplace_outs, inplace_outs=c_inplace_outs,
init_outputs=init_outputs) init_outputs=init_outputs)
# print >> sys.stderr, i, "DEBUGMODE thunk_c %100s %50s %30s" % (node,
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.inputs],
#[(id(o), numpy.asarray(storage_map[o][0])[0,0]) for o in node.outputs])
sys.stdout.flush() sys.stdout.flush()
# we're done with this thunk # we're done with this thunk
# clear everything out of the storage_map # clear everything out of the storage_map
for r in node.inputs: for r in node.inputs:
#print >> sys.stderr, i, "DEBUGMODE clearing input", r
storage_map[r][0] = None storage_map[r][0] = None
_logger.debug("%i - done with node", i) _logger.debug("%i - done with node", i)
...@@ -2059,7 +2096,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -2059,7 +2096,8 @@ class _Linker(gof.link.LocalLinker):
# But it is very slow and it is not sure it will help. # But it is very slow and it is not sure it will help.
gc.collect() gc.collect()
_find_bad_optimizations(order, fgraph.equivalence_tracker.reasons, _find_bad_optimizations(order,
fgraph.equivalence_tracker.reasons,
r_vals) r_vals)
##### #####
...@@ -2081,18 +2119,24 @@ class _Linker(gof.link.LocalLinker): ...@@ -2081,18 +2119,24 @@ class _Linker(gof.link.LocalLinker):
for r in r_vals: for r in r_vals:
if r.owner is None: if r.owner is None:
if r in fgraph.inputs: if r in fgraph.inputs:
assert storage_map[r] is input_storage[fgraph.inputs.index(r)] assert (storage_map[r] is
input_storage[fgraph.inputs.index(r)])
storage_map[r][0] = r_vals[r] storage_map[r][0] = r_vals[r]
# if an input was destroyed, the destroyed value should be returned # if an input was destroyed, the destroyed value
# should be returned
for r in dr_vals: for r in dr_vals:
assert dr_vals[r][0] is not None assert dr_vals[r][0] is not None
if r.owner is None: if r.owner is None:
assert r in fgraph.inputs assert r in fgraph.inputs
# HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION TOOK PLACE # HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION
if type(dr_vals[r][0]) in (numpy.ndarray, numpy.memmap) \ # TOOK PLACE
and dr_vals[r][0].dtype == storage_map[r][0].dtype \ if ((type(dr_vals[r][0]) in
and dr_vals[r][0].shape == storage_map[r][0].shape: (numpy.ndarray, numpy.memmap)) and
(dr_vals[r][0].dtype ==
storage_map[r][0].dtype) and
(dr_vals[r][0].shape ==
storage_map[r][0].shape)):
if len(dr_vals[r][0].shape): if len(dr_vals[r][0].shape):
storage_map[r][0][:] = dr_vals[r][0] storage_map[r][0][:] = dr_vals[r][0]
else: else:
...@@ -2111,10 +2155,6 @@ class _Linker(gof.link.LocalLinker): ...@@ -2111,10 +2155,6 @@ class _Linker(gof.link.LocalLinker):
storage_map[r][0] = None storage_map[r][0] = None
raise raise
# print ""
# print output_storage
# print dr_vals
# print storage_map
for r in storage_map: for r in storage_map:
if (r.owner is None): if (r.owner is None):
if not r.type.is_valid_value(None): if not r.type.is_valid_value(None):
...@@ -2130,12 +2170,14 @@ class _Linker(gof.link.LocalLinker): ...@@ -2130,12 +2170,14 @@ class _Linker(gof.link.LocalLinker):
# so it will screw up if we are trying to use # so it will screw up if we are trying to use
# multiple modes at once. # multiple modes at once.
old_filter_checks_isfinite = TensorType.filter_checks_isfinite old_filter_checks_isfinite = TensorType.filter_checks_isfinite
TensorType.filter_checks_isfinite = self.maker.mode.check_isfinite TensorType.filter_checks_isfinite = \
self.maker.mode.check_isfinite
try: try:
return f() return f()
finally: finally:
# put back the filter_checks_isfinite # put back the filter_checks_isfinite
TensorType.filter_checks_isfinite = old_filter_checks_isfinite TensorType.filter_checks_isfinite = \
old_filter_checks_isfinite
return deco return deco
f = run_with_tensortype_filter_check(f) f = run_with_tensortype_filter_check(f)
...@@ -2143,12 +2185,12 @@ class _Linker(gof.link.LocalLinker): ...@@ -2143,12 +2185,12 @@ class _Linker(gof.link.LocalLinker):
f.allow_gc = True f.allow_gc = True
assert len(fgraph.inputs) == len(input_storage) assert len(fgraph.inputs) == len(input_storage)
assert len(fgraph.outputs) == len(output_storage) assert len(fgraph.outputs) == len(output_storage)
# print 'make_all returning output', [id(z) for z in output_storage] return (f,
return f, [link.Container(input, storage, readonly=False) [link.Container(input, storage, readonly=False)
for input, storage in zip(fgraph.inputs, input_storage)], \ for input, storage in zip(fgraph.inputs, input_storage)],
[link.Container(output, storage, readonly=True) [link.Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)], \ for output, storage in zip(fgraph.outputs, output_storage)],
thunks_py, order thunks_py, order)
_NODEFAULT = ['NODEFAULT'] _NODEFAULT = ['NODEFAULT']
...@@ -2170,25 +2212,29 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2170,25 +2212,29 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
:type outputs: a list of SymbolicOutput instances :type outputs: a list of SymbolicOutput instances outputs may
outputs may also be a single Variable (not a list), in which also be a single Variable (not a list), in
case the functions produced by FunctionMaker will return which case the functions produced by
their output value directly FunctionMaker will return their output value
directly
:param accept_inplace: True iff it is acceptable to have :param accept_inplace: True iff it is acceptable to have
inplace operations in the graph from the inputs to inplace operations in the graph from the inputs to
the outputs the outputs
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the
not used in the graph. Possible values are 'raise', 'warn', and 'ignore'. 'inputs' list is not used in the
graph. Possible values are 'raise',
'warn', and 'ignore'.
:param output_keys: If the outputs argument for theano.function was a :param output_keys: If the outputs argument for
list, then output_keys is None. If the outputs argument was a dict, theano.function was a list, then
then output_keys is a sorted list of the keys from that dict. output_keys is None. If the outputs
argument was a dict, then output_keys is a
sorted list of the keys from that dict.
:note: this function sets TensorType.filter_checks_isfinite :note: this function sets TensorType.filter_checks_isfinite
when `mode.check_isfinite` is True when `mode.check_isfinite` is True
""" """
self.profile = profile self.profile = profile
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
...@@ -2205,7 +2251,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2205,7 +2251,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [inputs] inputs = [inputs]
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs) inputs = map(self.wrap_in, inputs)
outputs = map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + _inputs = gof.graph.inputs([o.variable for o in outputs] +
[i.update for i in inputs [i.update for i in inputs
if getattr(i, 'update', False)]) if getattr(i, 'update', False)])
...@@ -2213,9 +2260,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2213,9 +2260,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self._check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one # Make a list of (SymbolicInput|SymblicInputKits, indices,
# tuple for each input. (See Function.indices for more details) # [SymbolicInput,...]), one tuple for each input. (See
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] # Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs)
for input in inputs]
# make the fgraph # make the fgraph
for i in xrange(mode.stability_patience): for i in xrange(mode.stability_patience):
...@@ -2226,58 +2275,53 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2226,58 +2275,53 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# optimize the fgraph # optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value compute_test_value_orig = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = theano.config.compute_test_value_opt theano.config.compute_test_value = \
theano.config.compute_test_value_opt
optimizer(fgraph) optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(fgraph, inputs, theano.compile.function_module.insert_deepcopy(
outputs + additional_outputs) fgraph, inputs, outputs + additional_outputs)
finally: finally:
theano.config.compute_test_value = compute_test_value_orig theano.config.compute_test_value = compute_test_value_orig
if i: if i == 0:
fgraph0 = fgraph
else:
li = fgraph.equivalence_tracker.event_list li = fgraph.equivalence_tracker.event_list
l0 = fgraph0.equivalence_tracker.event_list l0 = fgraph0.equivalence_tracker.event_list
if li != l0 : if li != l0:
infolog = StringIO() infolog = StringIO()
print("WARNING: Optimization process is unstable...", file=infolog) print("WARNING: Optimization process is unstable...",
print(" (HINT: Ops that the nodes point to must compare equal)", file=infolog) file=infolog)
print("(event index) (one event trace) (other event trace)", file=infolog) print(" (HINT: Ops that the nodes point to must compare "
print("-----------------------------------------------------", file=infolog) "equal)", file=infolog)
print("(event index) (one event trace) (other event "
"trace)", file=infolog)
print("-------------------------------------------------"
"----", file=infolog)
for j in xrange(max(len(li), len(l0))): for j in xrange(max(len(li), len(l0))):
if j >= len(li): if j >= len(li):
print('trailing event in optimization 0 :', j, file=infolog) print('trailing event in optimization 0 :', j,
file=infolog)
print(' ', str(l0[j]), file=infolog) print(' ', str(l0[j]), file=infolog)
elif j >= len(l0): elif j >= len(l0):
print('trailing event in optimization', i, ':', j, file=infolog) print('trailing event in optimization', i, ':',
j, file=infolog)
print(' ', str(li[j]), file=infolog) print(' ', str(li[j]), file=infolog)
elif li[j] != l0[j]: elif li[j] != l0[j]:
print('non-equal optimization events', i, ':', j, file=infolog) print('non-equal optimization events', i, ':',
j, file=infolog)
print(' ', str(l0[j]), file=infolog) print(' ', str(l0[j]), file=infolog)
print(' ', str(li[j]), file=infolog) print(' ', str(li[j]), file=infolog)
#print >> infolog, "* ", j,
# if j < len(li):
# msg = str(li[j])
# else:
# msg = '-'
#print >> infolog, " ", msg
# if j < len(l0):
# msg = str(l0[j])
# else:
# msg = '-'
#print >> infolog, " ", msg
else: else:
pass pass
raise StochasticOrder(infolog.getvalue()) raise StochasticOrder(infolog.getvalue())
else: else:
if self.verbose: if self.verbose:
print("OPTCHECK: optimization", i, \ print("OPTCHECK: optimization", i,
"of", len(li), "events was stable.", file=sys.stderr) "of", len(li), "events was stable.",
else: file=sys.stderr)
fgraph0 = fgraph
del fgraph0
self.fgraph = fgraph self.fgraph = fgraph
# equivalence_tracker.printstuff()
linker = _Linker(self) linker = _Linker(self)
...@@ -2532,8 +2576,7 @@ class DebugMode(Mode): ...@@ -2532,8 +2576,7 @@ class DebugMode(Mode):
raise Exception("DebugMode can only use its own linker! You " raise Exception("DebugMode can only use its own linker! You "
"should not provide one.", linker) "should not provide one.", linker)
super(DebugMode, self).__init__( super(DebugMode, self).__init__(optimizer=optimizer,
optimizer=optimizer,
linker=linker) linker=linker)
if stability_patience is not None: if stability_patience is not None:
......
"""Define the `function` function """Define the `function` function
""" """
__docformat__ = "restructuredtext en"
import cPickle import cPickle
import logging import logging
_logger = logging.getLogger('theano.compile.function')
import traceback as tb import traceback as tb
import re import re
...@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function ...@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
from numpy import any from numpy import any
import warnings import warnings
from theano import gof
from theano import compat from theano import compat
__docformat__ = "restructuredtext en"
_logger = logging.getLogger('theano.compile.function')
def function_dump(filename, inputs, outputs=None, mode=None, updates=None, def function_dump(filename, inputs, outputs=None, mode=None, updates=None,
givens=None, givens=None,
...@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `Mode` instance. :type mode: string or `Mode` instance.
:param mode: compilation mode :param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or OrderedDict. :type updates: iterable over pairs (shared_variable, new_expression).
:param updates: update the values for SharedVariable inputs according to these expressions List, tuple or OrderedDict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1 :type givens: iterable over pairs (Var1, Var2) of Variables. List,
and Var2 in each pair must have the same Type. tuple or dict. The Var1 and Var2 in each pair must
have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation
Var1). graph (Var2 replaces Var1).
:type no_default_updates: either bool or list of Variables :type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables. :param no_default_updates: if True, do not perform any automatic
If False (default), perform them all. Else, perform automatic updates on all Variables update on Variables. If False (default), perform them
that are neither in "updates" nor in "no_default_updates". all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
:param name: an optional name for this function. The profile mode
:param rebuild_strict: True (Default) is the safer and better tested setting, in which case will print the time spent in this function.
`givens` must substitute new variables with the same Type as the variables they replace.
False is a you-better-know-what-you-are-doing setting, that permits `givens` to replace :param rebuild_strict: True (Default) is the safer and better
variables with new variables of any Type. The consequence of changing a Type is that all tested setting, in which case `givens` must substitute new
results depending on that variable may have a different Type too (the graph is rebuilt from variables with the same Type as the variables they replace.
inputs to outputs). If one of the new types does not make sense for one of the Ops in the False is a you-better-know-what-you-are-doing setting, that
permits `givens` to replace variables with new variables of
any Type. The consequence of changing a Type is that all
results depending on that variable may have a different Type
too (the graph is rebuilt from inputs to outputs). If one of
the new types does not make sense for one of the Ops in the
graph, an Exception will be raised. graph, an Exception will be raised.
:type allow_input_downcast: Boolean or None :type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as :param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to
the dtype of the corresponding Variable, which may lose precision. fit the dtype of the corresponding Variable, which may lose
False means that it will only be cast to a more general, or precision. False means that it will only be cast to a more
precise, type. None (default) is almost like False, but allows general, or precise, type. None (default) is almost like
downcasting of Python float scalars to floatX. False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, or ProfileStats instance :type profile: None, True, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats :param profile: accumulate profiling information into a given
instance. If argument is `True` then a new ProfileStats instance will be ProfileStats instance. If argument is `True` then a new
used. This profiling object will be available via self.profile. ProfileStats instance will be used. This profiling object
will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is :param on_unused_input: What to do if a variable in the 'inputs'
not used in the graph. Possible values are 'raise', 'warn', 'ignore' and None. list is not used in the graph. Possible values are 'raise',
'warn', 'ignore' and None.
:rtype: Function instance :rtype: Function instance
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given
and update the implicit function arguments according to the `updates`. the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in substitutions are independent--behaviour when Var1 of one pair
another expression is undefined. Replacements specified with givens are different from appears in the graph leading to Var2 in another expression is
optimizations in that Var2 is not expected to be equivalent to Var1. undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
Internal documentation: Internal documentation:
...@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
was easier to develop the VM in Python then translate it to C instead was easier to develop the VM in Python then translate it to C instead
of just writing it in C from scratch. of just writing it in C from scratch.
CVM stands for C Virtual Machine. CVM stands for C Virtual Machine.
""" """
if isinstance(outputs, dict): if isinstance(outputs, dict):
output_items = outputs.items() output_items = outputs.items()
...@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
output_keys.append(pair[0]) output_keys.append(pair[0])
outputs.append(pair[1]) outputs.append(pair[1])
else: else:
output_keys = None output_keys = None
...@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if givens is None: if givens is None:
givens = [] givens = []
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
raise Exception("Input variables of a Theano function should be" raise Exception("Input variables of a Theano function should be "
" contained in a list, even when there is a single input.") "contained in a list, even when there is a single "
"input.")
# compute some features of the arguments: # compute some features of the arguments:
uses_In = any([isinstance(i, In) for i in inputs]) # N.B. the square brackets are ncessary uses_In = any([isinstance(i, In) for i in inputs])
uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs]) # N.B. the square brackets are ncessary uses_tuple = any([isinstance(i, (list, tuple)) for i in inputs])
uses_updates = bool(updates) uses_updates = bool(updates)
uses_givens = bool(givens) uses_givens = bool(givens)
...@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if uses_In or uses_tuple: if uses_In or uses_tuple:
# we must use old semantics in this case. # we must use old semantics in this case.
if profile: if profile:
raise NotImplementedError('profiling not supported in old-style function') raise NotImplementedError("profiling not supported in old-style "
"function")
if uses_updates or uses_givens: if uses_updates or uses_givens:
raise NotImplementedError( raise NotImplementedError(
"In() instances and tuple inputs trigger the old " "In() instances and tuple inputs trigger the old "
...@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None, ...@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode=mode, mode=mode,
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
else: else:
# note: pfunc will also call orig_function-- orig_function is a choke point # note: pfunc will also call orig_function-- orig_function is
# that all compilation must pass through # a choke point that all compilation must pass through
fn = pfunc(params=inputs, fn = pfunc(params=inputs,
outputs=outputs, outputs=outputs,
mode=mode, mode=mode,
......
"""Driver of graph construction, optimization, and linking. """Driver of graph construction, optimization, and linking.
""" """
from __future__ import print_function from __future__ import print_function
__docformat__ = "restructuredtext en"
import copy import copy
import copy_reg import copy_reg
import cPickle import cPickle
...@@ -26,6 +23,8 @@ from theano.gof.op import ops_with_inner_function ...@@ -26,6 +23,8 @@ from theano.gof.op import ops_with_inner_function
import logging import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
__docformat__ = "restructuredtext en"
class UnusedInputError(Exception): class UnusedInputError(Exception):
""" """
...@@ -35,7 +34,7 @@ class UnusedInputError(Exception): ...@@ -35,7 +34,7 @@ class UnusedInputError(Exception):
def alias_root(v): def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps""" "Return the variable to which v is aliased by view_maps and destroy_maps"
if v.owner is None: if v.owner is None:
return v return v
vmap = getattr(v.owner.op, 'view_map', {}) vmap = getattr(v.owner.op, 'view_map', {})
...@@ -54,7 +53,8 @@ def alias_root(v): ...@@ -54,7 +53,8 @@ def alias_root(v):
def view_tree_set(v, treeset): def view_tree_set(v, treeset):
"""Add to `treeset` all variables that are views of v, given that v is not a view""" """Add to `treeset` all variables that are views of v, given that v is
not a view"""
treeset.add(v) treeset.add(v)
for cl, v_input_pos_to_cl in v.clients: for cl, v_input_pos_to_cl in v.clients:
if cl == 'output': if cl == 'output':
...@@ -69,11 +69,13 @@ def view_tree_set(v, treeset): ...@@ -69,11 +69,13 @@ def view_tree_set(v, treeset):
def infer_reuse_pattern(fgraph, outputs_to_disown): def infer_reuse_pattern(fgraph, outputs_to_disown):
""" """
Given an fgraph and a list of variables, returns the list or set of all variables which may Given an fgraph and a list of variables, returns the list or set
share the same underlying data storage as any of the specified variables. Used internally of all variables which may share the same underlying data storage
by function, FunctionMaker. as any of the specified variables. Used internally by function,
FunctionMaker.
This list (or set) is also refered to as no_recycling sometimes, especially by linker code. This list (or set) is also refered to as no_recycling sometimes,
especially by linker code.
""" """
rval = set() rval = set()
for o in outputs_to_disown: for o in outputs_to_disown:
...@@ -103,10 +105,10 @@ def fgraph_updated_vars(fgraph, expanded_inputs): ...@@ -103,10 +105,10 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
class Supervisor: class Supervisor:
""" """
Listener for FunctionGraph events which makes sure that no operation overwrites the Listener for FunctionGraph events which makes sure that no
contents of protected Variables. The outputs of the FunctionGraph are protected by default. operation overwrites the contents of protected Variables. The
outputs of the FunctionGraph are protected by default.
""" """
def __init__(self, protected): def __init__(self, protected):
self.protected = list(protected) self.protected = list(protected)
...@@ -176,33 +178,38 @@ class AliasedMemoryError(Exception): ...@@ -176,33 +178,38 @@ class AliasedMemoryError(Exception):
# Function # Function
### ###
# unique id object used as a placeholder for duplicate entries
DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate entries DUPLICATE = ['DUPLICATE']
class Function(object): class Function(object):
""" """
Type of the functions returned by theano.function or theano.FunctionMaker.create. Type of the functions returned by theano.function or
theano.FunctionMaker.create.
`Function` is the callable object that does computation. It has the storage of inputs and
outputs, performs the packing and unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a symbolic node.
Functions are copyable via {{{fn.copy()}}} and {{{copy.copy(fn)}}}.
When a function is copied, this instance is duplicated. Contrast with self.maker
(instance of `FunctionMaker`) that is shared between copies.
The meaning of copying a function is that the containers and their current values will all be duplicated.
This requires that mutable inputs be copied, whereas immutable inputs may be shared between copies.
`Function` is the callable object that does computation. It has
the storage of inputs and outputs, performs the packing and
unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a
symbolic node.
Functions are copyable via {{{fn.copy()}}} and
{{{copy.copy(fn)}}}. When a function is copied, this instance is
duplicated. Contrast with self.maker (instance of
`FunctionMaker`) that is shared between copies. The meaning of
copying a function is that the containers and their current values
will all be duplicated. This requires that mutable inputs be
copied, whereas immutable inputs may be shared between copies.
A Function instance is hashable, on the basis of its memory address (its id). A Function instance is hashable, on the basis of its memory
address (its id).
A Function instance is only equal to itself. A Function instance is only equal to itself.
A Function instance may be serialized using the `pickle` or `cPickle` modules. A Function instance may be serialized using the `pickle` or
This will save all default inputs, the graph, and *** to the pickle file (WRITEME). `cPickle` modules. This will save all default inputs, the graph,
and *** to the pickle file (WRITEME).
A Function instance have a ``trust_input`` field that default to A Function instance have a ``trust_input`` field that default to
False. When True, we don't do extra check of the input to give False. When True, we don't do extra check of the input to give
...@@ -210,7 +217,6 @@ class Function(object): ...@@ -210,7 +217,6 @@ class Function(object):
the good results if you pass a python or numpy scalar instead of a the good results if you pass a python or numpy scalar instead of a
numpy tensor. C code should raise an error if you pass an object numpy tensor. C code should raise an error if you pass an object
of the wrong type. of the wrong type.
""" """
pickle_aliased_memory_strategy = 'warn' pickle_aliased_memory_strategy = 'warn'
...@@ -218,12 +224,11 @@ class Function(object): ...@@ -218,12 +224,11 @@ class Function(object):
Meaningful settings are: 'ignore', 'warn', 'raise' Meaningful settings are: 'ignore', 'warn', 'raise'
If the value is 'warn', then a message will be printed to stderr if aliased storage is If the value is 'warn', then a message will be printed to stderr
dectected during pickle.dump. if aliased storage is dectected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised if aliased storage is
detected during pickle.dump.
If the value is 'raise', then an AliasedMemoryError will be raised
if aliased storage is detected during pickle.dump.
""" """
input_storage = None input_storage = None
...@@ -233,24 +238,28 @@ class Function(object): ...@@ -233,24 +238,28 @@ class Function(object):
"""list of Container instances""" """list of Container instances"""
indices = None indices = None
"""list of (SymbolicInput|SymbolicInputKit, indices, [SymbolicInput,...]), one tuple for """list of (SymbolicInput|SymbolicInputKit, indices,
each input [SymbolicInput,...]), one tuple for each input
The first tuple element is the SymbolicInput object for the corresponding function input. The first tuple element is the SymbolicInput object for the
corresponding function input.
The second and third tuple elements are used only by Kits, which are deprecated. The second and third tuple elements are used only by Kits, which
are deprecated.
""" """
defaults = None defaults = None
""" list of 3-tuples, one 3-tuple for each input. """ list of 3-tuples, one 3-tuple for each input.
Tuple element 0: Bool: Is this input required at each function call? Tuple element 0: Bool: Is this input required at each function call?
Tuple element 1: Bool: Should this inputs value be reverted after each call? Tuple element 1: Bool: Should this inputs value be reverted after
each call?
Tuple element 2: Any: The value associated with this input. Tuple element 2: Any: The value associated with this input.
""" """
unpack_single = None unpack_single = None
"""Bool: for outputs lists of length 1, should the 0'th element be returned directly?""" """Bool: for outputs lists of length 1, should the 0'th element be
returned directly?"""
return_none = None return_none = None
"""Bool: whether the function should return None or not""" """Bool: whether the function should return None or not"""
...@@ -259,8 +268,8 @@ class Function(object): ...@@ -259,8 +268,8 @@ class Function(object):
"""FunctionMaker instance""" """FunctionMaker instance"""
fn = None fn = None
"""a function that evaluates the graph. Typically a linker's make_thunk method created this """a function that evaluates the graph. Typically a linker's
function.""" make_thunk method created this function."""
finder = None finder = None
"""Dictionary mapping several kinds of things to containers. """Dictionary mapping several kinds of things to containers.
...@@ -273,7 +282,8 @@ class Function(object): ...@@ -273,7 +282,8 @@ class Function(object):
- the name of the input - the name of the input
All entries map to the container or to DUPLICATE if an ambiguity is detected All entries map to the container or to DUPLICATE if an ambiguity
is detected
""" """
inv_finder = None inv_finder = None
...@@ -312,20 +322,22 @@ class Function(object): ...@@ -312,20 +322,22 @@ class Function(object):
input.distribute(value, indices, cs) input.distribute(value, indices, cs)
for c in cs: for c in cs:
c.provided += 1 c.provided += 1
# def assign(c, v):
#c.data = v
# Store the list of names of named inputs. # Store the list of names of named inputs.
named_inputs = [] named_inputs = []
# Count the number of un-named inputs. # Count the number of un-named inputs.
n_unnamed_inputs = 0 n_unnamed_inputs = 0
#setters = []
# Initialize the storage # Initialize the storage
# this loop works by modifying the elements (as variable c) of self.input_storage inplace. # this loop works by modifying the elements (as variable c) of
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)): # self.input_storage inplace.
if indices is None: # this is true iff input is not a SymbolicInputKit for i, ((input, indices, sinputs), (required, refeed, value)) in \
c = containers[0] #containers is being used as a stack. Here we pop off the next one. enumerate(zip(self.indices, defaults)):
# this is true iff input is not a SymbolicInputKit
if indices is None:
# containers is being used as a stack. Here we pop off
# the next one.
c = containers[0]
c.strict = getattr(input, 'strict', False) c.strict = getattr(input, 'strict', False)
c.allow_downcast = getattr(input, 'allow_downcast', None) c.allow_downcast = getattr(input, 'allow_downcast', None)
...@@ -342,7 +354,9 @@ class Function(object): ...@@ -342,7 +354,9 @@ class Function(object):
c.value = value c.value = value
c.required = required c.required = required
c.implicit = input.implicit c.implicit = input.implicit
c.provided = 0 # this is a count of how many times the input has been provided (reinitialized to 0 on __call__) # this is a count of how many times the input has been
# provided (reinitialized to 0 on __call__)
c.provided = 0
finder[i] = c finder[i] = c
finder[input.variable] = c finder[input.variable] = c
if input.name not in finder: if input.name not in finder:
...@@ -353,17 +367,14 @@ class Function(object): ...@@ -353,17 +367,14 @@ class Function(object):
n_unnamed_inputs += 1 n_unnamed_inputs += 1
else: else:
named_inputs.append(input.name) named_inputs.append(input.name)
# backport
#finder[input.name] = c if input.name not in finder else DUPLICATE
# inv_finder maps the container to the input (useful for one error message)
inv_finder[c] = input inv_finder[c] = input
#setters.append(partial(assign, c))
containers[:1] = [] containers[:1] = []
else: else:
# TODO The following code may need to do something to handle # TODO The following code may need to do something to handle
# implicit inputs. # implicit inputs.
# The input is a SymbolicInputKit, so we take as many containers as the Kit provides inputs # The input is a SymbolicInputKit, so we take as many
# containers as the Kit provides inputs
cs = containers[:len(indices)] cs = containers[:len(indices)]
# distribute does the initialization of the containers # distribute does the initialization of the containers
input.distribute(value, indices, cs) input.distribute(value, indices, cs)
...@@ -377,12 +388,11 @@ class Function(object): ...@@ -377,12 +388,11 @@ class Function(object):
finder[input.name] = f finder[input.name] = f
else: else:
finder[input.name] = DUPLICATE finder[input.name] = DUPLICATE
# backport # For each input in the kit and its corresponding
#finder[input.name] = f if input.name not in finder else DUPLICATE # container, we put an entry in finder. This allows
# setters.append(f) # the user to micro-manage elements of the kit if need
# For each input in the kit and its corresponding container, we put an entry in finder. # be. All containers inherit the required field and
# This allows the user to micro-manage elements of the kit if need be. # have their own "provided" counter
# All containers inherit the required field and have their own "provided" counter
for c, sin in zip(cs, sinputs): for c, sin in zip(cs, sinputs):
finder[sin.variable] = c finder[sin.variable] = c
finder[sin.name] = c finder[sin.name] = c
...@@ -390,8 +400,6 @@ class Function(object): ...@@ -390,8 +400,6 @@ class Function(object):
finder[sin.name] = c finder[sin.name] = c
else: else:
finder[sin.name] = DUPLICATE finder[sin.name] = DUPLICATE
# backport
#finder[sin.name] = c if sin.name not in finder else DUPLICATE
inv_finder[c] = input inv_finder[c] = input
c.required = required c.required = required
c.provided = 0 c.provided = 0
...@@ -410,12 +418,14 @@ class Function(object): ...@@ -410,12 +418,14 @@ class Function(object):
except KeyError: except KeyError:
raise TypeError("Unknown input or state: %s" % str(item)) raise TypeError("Unknown input or state: %s" % str(item))
if s is DUPLICATE: if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\ raise TypeError("Ambiguous name: %s - please check the "
"of the inputs of your function for duplicates." % str(item)) "names of the inputs of your function "
"for duplicates." % str(item))
if isinstance(s, gof.Container): if isinstance(s, gof.Container):
return s.value return s.value
else: else:
raise NotImplementedError raise NotImplementedError
def __setitem__(self, item, value): def __setitem__(self, item, value):
try: try:
s = finder[item] s = finder[item]
...@@ -425,13 +435,15 @@ class Function(object): ...@@ -425,13 +435,15 @@ class Function(object):
raise TypeError("Unknown input or state: %s. %s" % raise TypeError("Unknown input or state: %s. %s" %
(str(item), msg)) (str(item), msg))
if s is DUPLICATE: if s is DUPLICATE:
raise TypeError("Ambiguous name: %s - please check the names "\ raise TypeError("Ambiguous name: %s - please check the "
"of the inputs of your function for duplicates." % str(item)) "names of the inputs of your function "
"for duplicates." % str(item))
if isinstance(s, gof.Container): if isinstance(s, gof.Container):
s.value = value s.value = value
s.provided += 1 s.provided += 1
else: else:
s(value) s(value)
def __contains__(self, item): def __contains__(self, item):
return finder.__contains__(item) return finder.__contains__(item)
...@@ -441,6 +453,7 @@ class Function(object): ...@@ -441,6 +453,7 @@ class Function(object):
class ContainerAttribute(object): class ContainerAttribute(object):
def __getitem__(self, item): def __getitem__(self, item):
return finder[item] return finder[item]
def __contains__(self, item): def __contains__(self, item):
return finder.__contains__(item) return finder.__contains__(item)
# You cannot set the container # You cannot set the container
...@@ -513,20 +526,17 @@ class Function(object): ...@@ -513,20 +526,17 @@ class Function(object):
s.storage[0] = arg s.storage[0] = arg
else: else:
try: try:
s.storage[0] = s.type.filter(arg, strict=s.strict, s.storage[0] = s.type.filter(
arg, strict=s.strict,
allow_downcast=s.allow_downcast) allow_downcast=s.allow_downcast)
except Exception as e: except Exception as e:
function_name = "theano function" function_name = "theano function"
if self.name: if self.name:
function_name += ' with name "' + self.name + '" ' function_name += ' with name "' + self.name + '" '
# end if e.args = ("Bad input argument to " + function_name +
e.args = tuple(["Bad input argument to " + function_name + " at index %d(0-based)" % i,) + e.args
" at index %d(0-based)" % i] +
list(e.args))
raise raise
# end except
# end if
s.provided += 1 s.provided += 1
i += 1 i += 1
...@@ -535,9 +545,8 @@ class Function(object): ...@@ -535,9 +545,8 @@ class Function(object):
for k, arg in kwargs.iteritems(): for k, arg in kwargs.iteritems():
self[k] = arg self[k] = arg
if not self.trust_input and ( if (not self.trust_input and
not hasattr(self, '_check_for_aliased_inputs') or getattr(self, '_check_for_aliased_inputs', True)):
self._check_for_aliased_inputs):
# Collect aliased inputs among the storage space # Collect aliased inputs among the storage space
args_share_memory = [] args_share_memory = []
for i in xrange(len(self.input_storage)): for i in xrange(len(self.input_storage)):
...@@ -566,10 +575,6 @@ class Function(object): ...@@ -566,10 +575,6 @@ class Function(object):
# Check for groups of more than one argument that share memory # Check for groups of more than one argument that share memory
for group in args_share_memory: for group in args_share_memory:
if len(group) > 1: if len(group) > 1:
# see if any of these arguments are mutable
mutable = numpy.any([(self.maker.inputs[idx].mutable or
self.maker.inputs[idx].borrow)
for idx in group])
# copy all but the first # copy all but the first
for idx in group[1:]: for idx in group[1:]:
self.input_storage[i].storage[0] = copy.copy( self.input_storage[i].storage[0] = copy.copy(
...@@ -696,13 +701,15 @@ class Function(object): ...@@ -696,13 +701,15 @@ class Function(object):
container = property( container = property(
lambda self: self._container, lambda self: self._container,
None, # this property itself is not settable None, # this property itself is not settable
doc="""dictionary-like access to the containers associated with Variables""") doc=("dictionary-like access to the containers associated with "
"Variables"))
def free(self): def free(self):
""" """
When allow_gc = False, clear the Variables in storage_map When allow_gc = False, clear the Variables in storage_map
""" """
# 1.no allow_gc return False 2.has allow_gc, if allow_gc is False, return True # 1.no allow_gc return False
# 2.has allow_gc, if allow_gc is False, return True
if not getattr(self.fn, 'allow_gc', True): if not getattr(self.fn, 'allow_gc', True):
for key in self.fn.storage_map.keys(): for key in self.fn.storage_map.keys():
if not isinstance(key, theano.gof.Constant): if not isinstance(key, theano.gof.Constant):
...@@ -719,7 +726,8 @@ def _pickle_Function(f): ...@@ -719,7 +726,8 @@ def _pickle_Function(f):
ins = list(f.input_storage) ins = list(f.input_storage)
input_storage = [] input_storage = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults): for (input, indices, inputs), (required, refeed, default) in \
zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
li = len(indices) li = len(indices)
if not default: if not default:
...@@ -734,18 +742,21 @@ def _pickle_Function(f): ...@@ -734,18 +742,21 @@ def _pickle_Function(f):
inputs_data = [x.data for x in f.input_storage] inputs_data = [x.data for x in f.input_storage]
# HACK to detect aliased storage. # HACK to detect aliased storage.
# This is here because aliased relationships are not [currently] preserved across the pickle operation # This is here because aliased relationships are not [currently]
# preserved across the pickle operation
if not (f.pickle_aliased_memory_strategy == 'ignore'): if not (f.pickle_aliased_memory_strategy == 'ignore'):
all_data = input_storage + inputs_data # addition here means list append all_data = input_storage + inputs_data
for i, d_i in enumerate(all_data): for i, d_i in enumerate(all_data):
for j, d_j in enumerate(all_data): for j, d_j in enumerate(all_data):
if (i < j) and isinstance(d_i, numpy.ndarray) and isinstance(d_j, numpy.ndarray): if ((i < j) and isinstance(d_i, numpy.ndarray) and
isinstance(d_j, numpy.ndarray)):
if numpy.may_share_memory(d_i, d_j): if numpy.may_share_memory(d_i, d_j):
if f.pickle_aliased_memory_strategy == 'warn': if f.pickle_aliased_memory_strategy == 'warn':
_logger.warning(('aliased relationship between' _logger.warning('aliased relationship between '
' Function arguments %s, %s' 'Function arguments %s, %s '
' will not be preserved by un-pickling' 'will not be preserved by '
' operation') % (str(d_i), str(d_j))) 'un-pickling operation' %
(str(d_i), str(d_j)))
else: else:
raise AliasedMemoryError(d_i, d_j) raise AliasedMemoryError(d_i, d_j)
rval = (_constructor_Function, (f.maker, input_storage, inputs_data)) rval = (_constructor_Function, (f.maker, input_storage, inputs_data))
...@@ -774,20 +785,25 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -774,20 +785,25 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
""" """
Insert deepcopy in the fgraph to break aliasing of outputs Insert deepcopy in the fgraph to break aliasing of outputs
""" """
# This loop was inserted to remove aliasing between outputs when they all # This loop was inserted to remove aliasing between outputs when
# evaluete to the same value. Originally it was OK for outputs to be aliased, # they all evaluete to the same value. Originally it was OK for
# but some of the outputs can be shared variables, and is not good for shared # outputs to be aliased, but some of the outputs can be shared
# variables to be aliased. It might be possible to optimize this by making sure # variables, and is not good for shared variables to be
# aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables. # there is no aliasing only between shared variables.
# If some outputs are constant, we add deep copy to respect the memory contract # If some outputs are constant, we add deep copy to respect the
# memory contract
# We don't insert deep copy when the output.borrow is True for all conserned outputs. # We don't insert deep copy when the output.borrow is True for all
# conserned outputs.
assert len(wrapped_inputs) == len(fgraph.inputs) assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs) assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy" reason = "insert_deepcopy"
updated_fgraph_inputs = [fgraph_i for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs) if getattr(i, 'update', False)] updated_fgraph_inputs = [fgraph_i for i, fgraph_i in
zip(wrapped_inputs, fgraph.inputs)
if getattr(i, 'update', False)]
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
...@@ -802,10 +818,12 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -802,10 +818,12 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow): # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
if fgraph.outputs[j] in views_of_output_i: if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow: if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_input('output', i, view_op(fgraph.outputs[i]), fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason=reason) reason=reason)
else: else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]), fgraph.change_input('output', i,
deep_copy_op(fgraph.outputs[i]),
reason=reason) reason=reason)
copied = True copied = True
break break
...@@ -813,31 +831,40 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -813,31 +831,40 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if not copied: if not copied:
for input_j in all_graph_inputs: for input_j in all_graph_inputs:
# do not allow outputs to be aliased to an inputs (j), unless # do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by e.g. in-place computations # a) that j'th input has been 'destroyed' by
# b) that j'th input is a shared variable that is also being updated # e.g. in-place computations
# b) that j'th input is a shared variable that is also
# being updated
if (hasattr(fgraph, 'get_destroyers_of') and if (hasattr(fgraph, 'get_destroyers_of') and
fgraph.get_destroyers_of(input_j)): fgraph.get_destroyers_of(input_j)):
continue continue
if input_j in updated_fgraph_inputs: if input_j in updated_fgraph_inputs:
continue continue
if input_j in views_of_output_i: if input_j in views_of_output_i:
# We don't put deep_copy_op if the input and the output have borrow==True # We don't put deep_copy_op if the input and the
# output have borrow==True
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
j = fgraph.inputs.index(input_j) j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow: if (wrapped_outputs[i].borrow and
fgraph.change_input('output', i, view_op(fgraph.outputs[i]), wrapped_inputs[j].borrow):
fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy") reason="insert_deepcopy")
break break
else: else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]), fgraph.change_input(
'output', i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy") reason="insert_deepcopy")
break break
elif wrapped_outputs[i].borrow: elif wrapped_outputs[i].borrow:
fgraph.change_input('output', i, view_op(fgraph.outputs[i]), fgraph.change_input('output', i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy") reason="insert_deepcopy")
break break
else: else:
fgraph.change_input('output', i, deep_copy_op(fgraph.outputs[i]), fgraph.change_input('output', i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy") reason="insert_deepcopy")
break break
...@@ -866,17 +893,20 @@ class FunctionMaker(object): ...@@ -866,17 +893,20 @@ class FunctionMaker(object):
if len(input) == 2: if len(input) == 2:
return SymbolicInput(input[0], update=input[1]) return SymbolicInput(input[0], update=input[1])
else: else:
raise TypeError("Expected two elements in the list or tuple.", input) raise TypeError("Expected two elements in the list or tuple.",
input)
else: else:
raise TypeError("Unknown input type: %s (%s), expected Variable instance", type(input), input) raise TypeError("Unknown input type: %s (%s), expected Variable "
"instance", type(input), input)
@staticmethod @staticmethod
def expand_in(sinput, rinputs): def expand_in(sinput, rinputs):
# For SymbolicInputKits, this extracts a list of SymbolicInput instances # For SymbolicInputKits, this extracts a list of SymbolicInput
# and corresponding indices such that these SymbolicInputs are representative # instances and corresponding indices such that these
# of some of the Variable instances in inputs. # SymbolicInputs are representative of some of the Variable
# For SymbolicInput, this returns None as the list of indices and a list with # instances in inputs. For SymbolicInput, this returns None
# just the SymbolicInput. # as the list of indices and a list with just the
# SymbolicInput.
if isinstance(sinput, SymbolicInputKit): if isinstance(sinput, SymbolicInputKit):
return sinput.complete(rinputs) return sinput.complete(rinputs)
elif isinstance(sinput, SymbolicInput): elif isinstance(sinput, SymbolicInput):
...@@ -889,24 +919,25 @@ class FunctionMaker(object): ...@@ -889,24 +919,25 @@ class FunctionMaker(object):
elif isinstance(output, gof.Variable): elif isinstance(output, gof.Variable):
return SymbolicOutput(output) return SymbolicOutput(output)
else: else:
raise TypeError("Unknown output type: %s (%s)", type(output), output) raise TypeError("Unknown output type: %s (%s)", type(output),
output)
def optimize_graph_with_cache(self, optimizer, inputs, outputs): def optimize_graph_with_cache(self, optimizer, inputs, outputs):
# This function is not finished # This function is not finished
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
import os.path import os.path
graph_db_file = os.path.join(theano.config.compiledir, 'optimized_graphs.pkl') graph_db_file = os.path.join(theano.config.compiledir,
'optimized_graphs.pkl')
# the inputs, outputs, and size of the graph to be optimized # the inputs, outputs, and size of the graph to be optimized
inputs_new = [inp.variable for inp in inputs] inputs_new = [inp.variable for inp in inputs]
outputs_new = [out.variable for out in outputs] outputs_new = [out.variable for out in outputs]
size_new = len(self.fgraph.apply_nodes) size_new = len(self.fgraph.apply_nodes)
need_optimize = False
get_lock() get_lock()
key = None
# Beginning of cache optimizations. # Beginning of cache optimizations.
# Could be refactored in different functions. # Could be refactored in different functions.
def load_graph_db(): def load_graph_db():
if os.path.isfile(graph_db_file): if os.path.isfile(graph_db_file):
print('graph_db already exists') print('graph_db already exists')
...@@ -919,8 +950,9 @@ class FunctionMaker(object): ...@@ -919,8 +950,9 @@ class FunctionMaker(object):
# load the graph_db dictionary # load the graph_db dictionary
try: try:
f = open(graph_db_file, 'rb') f = open(graph_db_file, 'rb')
# Temporary hack to allow theano.scan_module.tests.test_scan.T_Scan # Temporary hack to allow
# to finish. Should be changed in definitive version. # theano.scan_module.tests.test_scan.T_Scan to
# finish. Should be changed in definitive version.
tmp = theano.config.unpickle_function tmp = theano.config.unpickle_function
theano.config.unpickle_function = False theano.config.unpickle_function = False
graph_db = cPickle.load(f) graph_db = cPickle.load(f)
...@@ -961,16 +993,21 @@ class FunctionMaker(object): ...@@ -961,16 +993,21 @@ class FunctionMaker(object):
# two graphs are for sure different # two graphs are for sure different
print('need to optimize, because output size is different') print('need to optimize, because output size is different')
continue continue
elif not all(input_new.type == input_old.type for elif not all(input_new.type == input_old.type
input_new, input_old in zip(inputs_new, inputs_old)): for input_new, input_old in
print('need to optimize, because inputs are of different types') zip(inputs_new, inputs_old)):
print('need to optimize, because inputs are of different '
'types')
continue continue
elif not all(output_new.type == output_old.type for elif not all(output_new.type == output_old.type
output_new, output_old in zip(outputs_new, outputs_old)): for output_new, output_old in
print('need to optimize, because outputs are of different types') zip(outputs_new, outputs_old)):
print('need to optimize, because outputs are of different '
'types')
continue continue
elif not size_old == size_new: elif not size_old == size_new:
print('need to optimize, because numbers of nodes in graph are different') print('need to optimize, because numbers of nodes in graph'
' are different')
continue continue
else: else:
flags = [] flags = []
...@@ -1032,7 +1069,8 @@ class FunctionMaker(object): ...@@ -1032,7 +1069,8 @@ class FunctionMaker(object):
return found_graph_in_db return found_graph_in_db
graph_db = load_graph_db() graph_db = load_graph_db()
print('loaded graph_db from %s, size=%d' % (graph_db_file, len(graph_db))) print('loaded graph_db from %s, size=%d' % (graph_db_file,
len(graph_db)))
found_graph = find_same_graph_in_db(graph_db) found_graph = find_same_graph_in_db(graph_db)
if found_graph: if found_graph:
self.fgraph = found_graph self.fgraph = found_graph
...@@ -1043,7 +1081,7 @@ class FunctionMaker(object): ...@@ -1043,7 +1081,7 @@ class FunctionMaker(object):
self.fgraph.variables = set(gof.graph.variables( self.fgraph.variables = set(gof.graph.variables(
self.fgraph.inputs, self.fgraph.outputs)) self.fgraph.inputs, self.fgraph.outputs))
# check_integrity parameters was added to ignore # check_integrity parameters was added to ignore
#"excess cached variables" errors. Works that way # "excess cached variables" errors. Works that way
# but once again the error couldbe worth # but once again the error couldbe worth
# investigating. # investigating.
before_opt = self.fgraph.clone(check_integrity=False) before_opt = self.fgraph.clone(check_integrity=False)
...@@ -1063,16 +1101,18 @@ class FunctionMaker(object): ...@@ -1063,16 +1101,18 @@ class FunctionMaker(object):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
:type outputs: a list of SymbolicOutput instances
outputs may also be a single Variable (not a list), in which
case the functions produced by FunctionMaker will return
their output value directly
:param mode: a Mode instance telling FunctionMaker how to optimize and link. None :type outputs: a list of SymbolicOutput instances outputs may
means to use the `config.mode`. also be a single Variable (not a list), in which case the
functions produced by FunctionMaker will return their
output value directly
:param accept_inplace: True iff it is acceptable to have inplace operations :param mode: a Mode instance telling FunctionMaker how to
in the graph from the inputs to the outputs optimize and link. None means to use the `config.mode`.
:param accept_inplace: True iff it is acceptable to have
inplace operations in the graph from the inputs to the
outputs
:param on_unused_input: What to do if a variable in the 'inputs' list :param on_unused_input: What to do if a variable in the 'inputs' list
is not used in the graph. Possible values are: is not used in the graph. Possible values are:
...@@ -1098,9 +1138,10 @@ class FunctionMaker(object): ...@@ -1098,9 +1138,10 @@ class FunctionMaker(object):
# This is very important: # This is very important:
# 1) We preload the cache here to don't have its timming # 1) We preload the cache here to don't have its timming
# included in optimization that compile function. # included in optimization that compile function.
# 2) Do not refresh the cache here by default. It cause too much # 2) Do not refresh the cache here by default. It cause
# execution time during testing as we compile much more functions # too much execution time during testing as we compile
# then the number of compile c module. # much more functions then the number of compile c
# module.
theano.gof.cc.get_module_cache().refresh() theano.gof.cc.get_module_cache().refresh()
# Handle the case where inputs and/or outputs is a single # Handle the case where inputs and/or outputs is a single
# Variable (not in a list) # Variable (not in a list)
...@@ -1117,21 +1158,27 @@ class FunctionMaker(object): ...@@ -1117,21 +1158,27 @@ class FunctionMaker(object):
inputs = [inputs] inputs = [inputs]
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs) inputs = map(self.wrap_in, inputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update outputs = map(self.wrap_out, outputs)
for i in inputs if getattr(i, 'update', False)]) _inputs = gof.graph.inputs([o.variable for o in outputs] +
[i.update for i in inputs
if getattr(i, 'update', False)])
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self._check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one # Make a list of (SymbolicInput|SymblicInputKits, indices,
# tuple for each input. (See Function.indices for more details) # [SymbolicInput,...]), one tuple for each input. (See
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] # Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs)
for input in inputs]
if fgraph is None: if fgraph is None:
need_opt = True need_opt = True
# make the fgraph (copies the graph, creates NEW INPUT AND OUTPUT VARIABLES) # make the fgraph (copies the graph, creates NEW INPUT AND
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace) # OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs,
accept_inplace)
fgraph.profile = profile fgraph.profile = profile
else: else:
# fgraph is already an optimized one # fgraph is already an optimized one
...@@ -1149,7 +1196,8 @@ class FunctionMaker(object): ...@@ -1149,7 +1196,8 @@ class FunctionMaker(object):
# Why we add stack on node when it get done in output var? # Why we add stack on node when it get done in output var?
try: try:
# optimize the fgraph # optimize the fgraph
theano.config.compute_test_value = theano.config.compute_test_value_opt theano.config.compute_test_value = \
theano.config.compute_test_value_opt
theano.config.traceback.limit = 0 theano.config.traceback.limit = 0
start_optimizer = time.time() start_optimizer = time.time()
...@@ -1165,7 +1213,8 @@ class FunctionMaker(object): ...@@ -1165,7 +1213,8 @@ class FunctionMaker(object):
if profile: if profile:
profile.optimizer_time += opt_time profile.optimizer_time += opt_time
if theano.config.profile_optimizer: if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile) profile.optimizer_profile = (optimizer,
optimizer_profile)
_logger.debug('Optimizing took %f seconds', opt_time) _logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface # Add deep copy to respect the memory interface
...@@ -1176,14 +1225,19 @@ class FunctionMaker(object): ...@@ -1176,14 +1225,19 @@ class FunctionMaker(object):
# initialize the linker # initialize the linker
if not hasattr(linker, 'accept'): if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionMaker should be a Linker with an accept method " \ raise ValueError("'linker' parameter of FunctionMaker should be "
"or one of %s" % theano.compile.mode.predefined_linkers.keys()) "a Linker with an accept method or one of %s" %
theano.compile.mode.predefined_linkers.keys())
# the 'no_borrow' outputs are the ones for which that we can't return the internal storage pointer. # the 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs) assert len(fgraph.outputs) == len(outputs + additional_outputs)
no_borrow = [output for output, spec in zip(fgraph.outputs, outputs + additional_outputs) if not spec.borrow] no_borrow = [output for output, spec in
zip(fgraph.outputs, outputs + additional_outputs)
if not spec.borrow]
if no_borrow: if no_borrow:
self.linker = linker.accept(fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow)) self.linker = linker.accept(
fgraph, no_recycling=infer_reuse_pattern(fgraph, no_borrow))
else: else:
self.linker = linker.accept(fgraph) self.linker = linker.accept(fgraph)
...@@ -1209,8 +1263,7 @@ class FunctionMaker(object): ...@@ -1209,8 +1263,7 @@ class FunctionMaker(object):
(i.value is not None and (i.value is not None and
not isinstance(i.value, gof.Container) and not isinstance(i.value, gof.Container) and
i.update is None) i.update is None)
for i in self.inputs for i in self.inputs]
]
def _check_unused_inputs(self, inputs, outputs, on_unused_input): def _check_unused_inputs(self, inputs, outputs, on_unused_input):
if on_unused_input is None: if on_unused_input is None:
...@@ -1223,8 +1276,8 @@ class FunctionMaker(object): ...@@ -1223,8 +1276,8 @@ class FunctionMaker(object):
# - variables that have to be provided (used_inputs) # - variables that have to be provided (used_inputs)
# - shared variables that will be updated # - shared variables that will be updated
used_inputs = gof.graph.ancestors( used_inputs = gof.graph.ancestors(
([o.variable for o in outputs] ([o.variable for o in outputs] +
+ [i.update for i in inputs if getattr(i, 'update', False)]), [i.update for i in inputs if getattr(i, 'update', False)]),
blockers=[i.variable for i in inputs]) blockers=[i.variable for i in inputs])
msg = ("theano.function was asked to create a function computing " msg = ("theano.function was asked to create a function computing "
...@@ -1241,39 +1294,46 @@ class FunctionMaker(object): ...@@ -1241,39 +1294,46 @@ class FunctionMaker(object):
for i in inputs: for i in inputs:
if ((i.variable not in used_inputs) and (i.update is None)): if ((i.variable not in used_inputs) and (i.update is None)):
if on_unused_input == 'warn': if on_unused_input == 'warn':
warnings.warn(msg % (inputs.index(i), i.variable, warn_msg), stacklevel=6) warnings.warn(msg % (inputs.index(i), i.variable,
warn_msg), stacklevel=6)
elif on_unused_input == 'raise': elif on_unused_input == 'raise':
raise UnusedInputError(msg % (inputs.index(i), i.variable, err_msg)) raise UnusedInputError(msg % (inputs.index(i),
i.variable, err_msg))
else: else:
raise ValueError(("Invalid value for keyword " raise ValueError("Invalid value for keyword "
"on_unused_input of theano.function: '%s'. " "on_unused_input of theano.function: "
"valid values are 'raise', 'warn', and 'ignore'." "'%s'.\nValid values are 'raise', "
% on_unused_input)) "'warn', and 'ignore'." % on_unused_input)
def create(self, input_storage=None, trustme=False): def create(self, input_storage=None, trustme=False):
""" """
Create a function. Create a function.
input_storage -> a list matching the inputs list and providing default values input_storage -> a list matching the inputs list and providing
if the default for an input is None, then that input is a default values if the default for an input is
required input. For an input with an update, the default None, then that input is a required input. For an
acts as initialization. input with an update, the default acts as
initialization.
trustme -> disables some exceptions, used internally trustme -> disables some exceptions, used internally
""" """
if input_storage is None: if input_storage is None:
input_storage = [None] * len(self.inputs) input_storage = [None] * len(self.inputs)
input_storage_lists = [] # list of independent one-element lists, will be passed to the linker # list of independent one-element lists, will be passed to the linker
input_storage_lists = []
defaults = [] defaults = []
# The following loop is to fill in the input_storage_lists and defaults lists. # The following loop is to fill in the input_storage_lists and
# defaults lists.
assert len(self.indices) == len(input_storage) assert len(self.indices) == len(input_storage)
for i, ((input, indices, subinputs), input_storage_i) in enumerate(zip(self.indices, input_storage)): for i, ((input, indices, subinputs), input_storage_i) in \
enumerate(zip(self.indices, input_storage)):
# Replace any default value given as a variable by its container.
# Note that this makes sense only in the context of shared variables, # Replace any default value given as a variable by its
# but for now we avoid dealing directly with them to avoid dependency # container. Note that this makes sense only in the
# on the shared variables work-in-progress repository. # context of shared variables, but for now we avoid
# dealing directly with them to avoid dependency on the
# shared variables work-in-progress repository.
if isinstance(input_storage_i, gof.Variable): if isinstance(input_storage_i, gof.Variable):
input_storage_i = input_storage_i.container input_storage_i = input_storage_i.container
...@@ -1282,7 +1342,8 @@ class FunctionMaker(object): ...@@ -1282,7 +1342,8 @@ class FunctionMaker(object):
# share the same storage. This is done by appending # share the same storage. This is done by appending
# input_storage_i.storage to input_storage_lists. # input_storage_i.storage to input_storage_lists.
if indices is not None: if indices is not None:
raise TypeError("Cannot take a Container instance as default for a SymbolicInputKit.") raise TypeError("Cannot take a Container instance as "
"default for a SymbolicInputKit.")
input_storage_lists.append(input_storage_i.storage) input_storage_lists.append(input_storage_i.storage)
storage = input_storage[i].storage[0] storage = input_storage[i].storage[0]
...@@ -1295,7 +1356,8 @@ class FunctionMaker(object): ...@@ -1295,7 +1356,8 @@ class FunctionMaker(object):
required = self.required[i] required = self.required[i]
refeed = self.refeed[i] refeed = self.refeed[i]
# sanity check-- if an input is required it should not need to be refed # sanity check-- if an input is required it should not
# need to be refed
assert not (required and refeed) assert not (required and refeed)
# shared variables need neither be input by the user nor refed # shared variables need neither be input by the user nor refed
...@@ -1312,9 +1374,7 @@ class FunctionMaker(object): ...@@ -1312,9 +1374,7 @@ class FunctionMaker(object):
if storage is not None: if storage is not None:
assert refeed or not required assert refeed or not required
defaults.append((required, defaults.append((required, refeed, storage))
refeed,
storage))
# Get a function instance # Get a function instance
start_linker = time.time() start_linker = time.time()
...@@ -1338,7 +1398,8 @@ class FunctionMaker(object): ...@@ -1338,7 +1398,8 @@ class FunctionMaker(object):
self.profile.import_time += import_time self.profile.import_time += import_time
fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs, fn = self.function_builder(_fn, _i, _o, self.indices, self.outputs,
defaults, self.unpack_single, self.return_none, self.output_keys, self) defaults, self.unpack_single,
self.return_none, self.output_keys, self)
fn.profile = self.profile fn.profile = self.profile
return fn return fn
...@@ -1367,19 +1428,6 @@ def _constructor_FunctionMaker(kwargs): ...@@ -1367,19 +1428,6 @@ def _constructor_FunctionMaker(kwargs):
copy_reg.pickle(FunctionMaker, _pickle_FunctionMaker) copy_reg.pickle(FunctionMaker, _pickle_FunctionMaker)
try:
# Pickle of slice is implemented on python 2.6. To enabled be
# compatible with python 2.4, we implement pickling of slice
# ourself.
cPickle.dumps(slice(0, 10, 100))
except TypeError:
# This slice pickle implementation seam backward and forward compatible.
def _pickle_slice(s):
return (slice, (s.start, s.stop, s.step))
copy_reg.pickle(slice, _pickle_slice)
__checkers = [] __checkers = []
...@@ -1390,7 +1438,6 @@ def check_equal(x, y): ...@@ -1390,7 +1438,6 @@ def check_equal(x, y):
except Exception: except Exception:
continue continue
return x == y return x == y
#raise Exception('No checker for equality between %s and %s' % (x, y))
def register_checker(checker): def register_checker(checker):
...@@ -1405,10 +1452,10 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1405,10 +1452,10 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
:param inputs: list of `SymbolicInput` or `In` instances :param inputs: list of `SymbolicInput` or `In` instances
:param outputs: a SymbolicOutput or a list of `SymbolicOutput` or `Out` :param outputs: a SymbolicOutput or a list of `SymbolicOutput` or
instances. The return value of the returned function will match the `Out` instances. The return value of the returned function
format of this argument (either the value itself or a list of one or more will match the format of this argument (either the value
return values) itself or a list of one or more return values)
:param mode: a descriptive string or a Mode instance. (Default of None :param mode: a descriptive string or a Mode instance. (Default of None
means to use `config.mode` (See below for descriptive string list). means to use `config.mode` (See below for descriptive string list).
...@@ -1422,7 +1469,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1422,7 +1469,8 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
- FAST_COMPILE (minimal optimization) - FAST_COMPILE (minimal optimization)
- ProfileMode(deprecated): allow to print a profile mode with mode.print_summary - ProfileMode(deprecated): allow to print a profile mode with
mode.print_summary
- DebugMode: verify many internal conditions that are normally assumed - DebugMode: verify many internal conditions that are normally assumed
(slow) (slow)
...@@ -1471,7 +1519,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1471,7 +1519,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys = output_keys).create( output_keys=output_keys).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
...@@ -1590,15 +1638,15 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs): ...@@ -1590,15 +1638,15 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
"constructor to give it a name)." % n_unnamed_inputs) "constructor to give it a name)." % n_unnamed_inputs)
else: else:
if n_unnamed_inputs == 0: if n_unnamed_inputs == 0:
msg = ("The function has %s named input%s (%s)." % ( msg = ("The function has %s named input%s (%s)." %
n_named_inputs, get_plural(n_named_inputs), (n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs))) ', '.join(named_inputs)))
else: else:
msg = ("The function has %s named input%s (%s), and %s unnamed " msg = ("The function has %s named input%s (%s), and %s unnamed "
"input%s which thus cannot be accessed through keyword " "input%s which thus cannot be accessed through keyword "
"argument%s (use 'name=...' in a variable's constructor " "argument%s (use 'name=...' in a variable's constructor "
"to give it a name)." % ( "to give it a name)." %
n_named_inputs, get_plural(n_named_inputs), (n_named_inputs, get_plural(n_named_inputs),
', '.join(named_inputs), n_unnamed_inputs, ', '.join(named_inputs), n_unnamed_inputs,
get_plural(n_unnamed_inputs), get_plural(n_unnamed_inputs),
get_plural(n_unnamed_inputs))) get_plural(n_unnamed_inputs)))
......
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """ """Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__ = 'restructuredtext en'
from theano import gof from theano import gof
from sharedvalue import SharedVariable from sharedvalue import SharedVariable
...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable ...@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
import logging import logging
_logger = logging.getLogger("theano.compile.io") _logger = logging.getLogger("theano.compile.io")
__docformat__ = 'restructuredtext en'
class SymbolicInput(object): class SymbolicInput(object):
""" """
...@@ -17,34 +18,47 @@ class SymbolicInput(object): ...@@ -17,34 +18,47 @@ class SymbolicInput(object):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
mutable: Bool (default: False if update is None, True if update is not None) False: do not permit the compiled function to modify the
True: permit the compiled function to modify the python object being passed as the input python object being passed as the input.
False: do not permit the compiled function to modify the python object being passed as the input.
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
implicit: Bool (default: False) implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we are in the See help(In). Note that 'None' is not allowed here, since we
symbolic case. are in the symbolic case.
""" """
def __init__(self, variable, name=None, update=None, mutable=None, def __init__(self, variable, name=None, update=None, mutable=None,
...@@ -146,36 +160,54 @@ class In(SymbolicInput): ...@@ -146,36 +160,54 @@ class In(SymbolicInput):
not computed from its owner. not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name). name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by kwarg, and its value If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>. kwarg, and its value can be accessed by self.<name>.
value: Any type. value: Any type.
The initial/default value for this input. If update is None, this input acts just like The initial/default value for this input. If update is None,
an argument with a default value in Python. If update is not None, changes to this this input acts just like an argument with a default value in
value will "stick around", whether due to an update or a user's explicit action. Python. If update is not None, changes to this value will
"stick around", whether due to an update or a user's explicit
action.
update: Variable instance (default: None) update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call. value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input. variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None) mutable: Bool (default: False if update is None, True if update is
True: permit the compiled function to modify the python object being passed as the input not None)
False: do not permit the compiled function to modify the python object being passed as the input.
True: permit the compiled function to modify the python object
being passed as the input
False: do not permit the compiled function to modify the
python object being passed as the input.
borrow: Bool (default: take the same value as mutable) borrow: Bool (default: take the same value as mutable)
True: permit the output of the compiled function to be aliased to the input
True: permit the output of the compiled function to be aliased
to the input
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX. False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -194,11 +226,11 @@ class In(SymbolicInput): ...@@ -194,11 +226,11 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt, # Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=None, autoname=True, mutable=None, strict=False, allow_downcast=None,
implicit=None, borrow=None, shared=False): autoname=True, implicit=None, borrow=None, shared=False):
# if shared, an input's value comes from its persistent
# if shared, an input's value comes from its persistent storage, not from a default stored # storage, not from a default stored in the function or from
# in the function or from the caller # the caller
self.shared = shared self.shared = shared
if borrow is None: if borrow is None:
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
""" """
from __future__ import print_function from __future__ import print_function
import logging import logging
import warnings
from textwrap import dedent
import numpy import numpy
...@@ -11,24 +9,24 @@ import theano ...@@ -11,24 +9,24 @@ import theano
from theano import gof from theano import gof
import theano.gof.vm import theano.gof.vm
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.compile.ops import register_view_op_c_code, _output_guard from theano.compile.ops import _output_guard
_logger = logging.getLogger('theano.compile.mode') _logger = logging.getLogger('theano.compile.mode')
AddConfigVar('optimizer_excluding', AddConfigVar('optimizer_excluding',
("When using the default mode, we will remove optimizer with these " ("When using the default mode, we will remove optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_including', AddConfigVar('optimizer_including',
("When using the default mode, we will add optimizer with these tags. " ("When using the default mode, we will add optimizer with "
"Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
AddConfigVar('optimizer_requiring', AddConfigVar('optimizer_requiring',
("When using the default mode, we will require optimizer with these " ("When using the default mode, we will require optimizer with "
"tags. Separate tags with ':'."), "these tags. Separate tags with ':'."),
StrParam("", allow_override=False), StrParam("", allow_override=False),
in_c_key=False) in_c_key=False)
...@@ -50,9 +48,9 @@ def check_equal(x, y): ...@@ -50,9 +48,9 @@ def check_equal(x, y):
y = y.todense() y = y.todense()
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
if (x.dtype != y.dtype if (x.dtype != y.dtype or
or x.shape != y.shape x.shape != y.shape or
or numpy.any(abs(x - y) > 1e-10)): numpy.any(abs(x - y) > 1e-10)):
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
{'performlinker': x, 'clinker': y}) {'performlinker': x, 'clinker': y})
else: else:
...@@ -287,7 +285,8 @@ class Mode(object): ...@@ -287,7 +285,8 @@ class Mode(object):
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__, return "%s(linker = %s, optimizer = %s)" % (self.__class__.__name__,
self.provided_linker, self.provided_optimizer) self.provided_linker,
self.provided_optimizer)
def __get_optimizer(self): def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query): if isinstance(self._optimizer, gof.Query):
...@@ -364,10 +363,11 @@ def get_mode(orig_string): ...@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker. # DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer) ret = DebugMode(optimizer=config.optimizer)
else: else:
# The import is needed in case string is ProfileMode # This might be required if the string is 'ProfileMode'
from profilemode import ProfileMode, prof_mode_instance_to_print from profilemode import ProfileMode # noqa
ret = eval(string from profilemode import prof_mode_instance_to_print
+ '(linker=config.linker, optimizer=config.optimizer)') ret = eval(string +
'(linker=config.linker, optimizer=config.optimizer)')
elif string in predefined_modes: elif string in predefined_modes:
ret = predefined_modes[string] ret = predefined_modes[string]
else: else:
......
from __future__ import print_function from __future__ import print_function
# Note: this code was initially copied from the 'pyutools' package by its # Note: this code was initially copied from the 'pyutools' package by its
# original author, and re-licensed under Theano's license. # original author, and re-licensed under Theano's license.
import numpy
import theano import theano
from theano.compile.mode import Mode from theano.compile.mode import Mode
......
...@@ -71,11 +71,12 @@ class ViewOp(gof.Op): ...@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for ViewOp, but it has " warnings.warn("Type %s has C code for ViewOp, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_view_op_c_code." % t, "arg when calling register_view_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op): ...@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for DeepCopyOp, but it has " warnings.warn("Type %s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg " "no version. You should add a 'version' keyword"
"when calling register_deep_copy_op_c_code." % t, " arg when calling "
"register_deep_copy_op_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -284,11 +287,12 @@ class Shape(gof.Op): ...@@ -284,11 +287,12 @@ class Shape(gof.Op):
version = [] version = []
# If any of the c code is unversionned, we have to return () # If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs. # Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(self.c_code_and_version.items(), key=lambda pair: str(pair[0])): for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Shape, but it has " warnings.warn("Type %s has C code for Shape, but it has no "
"no version. You should add a 'version' keyword arg " "version. You should add a 'version' keyword "
"when calling register_shape_c_code." % t, "arg when calling register_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -301,7 +305,6 @@ class Shape(gof.Op): ...@@ -301,7 +305,6 @@ class Shape(gof.Op):
shape = Shape() shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape'))
class Shape_i(gof.Op): class Shape_i(gof.Op):
...@@ -389,8 +392,11 @@ class Shape_i(gof.Op): ...@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
return [()] return [()]
def grad(self, inp, grads): def grad(self, inp, grads):
return [theano.gradient.grad_not_implemented(op=self, x_pos=0, x=inp[0], return [theano.gradient.grad_not_implemented(
comment="No gradient for the shape of a matrix is implemented.")] op=self, x_pos=0, x=inp[0],
comment=("No gradient for the shape of a matrix "
"is implemented."))]
def shape_i(var, i, fgraph=None): def shape_i(var, i, fgraph=None):
"""Equivalent of var.shape[i], but apply if possible the shape """Equivalent of var.shape[i], but apply if possible the shape
...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None): ...@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
def register_shape_i_c_code(typ, code, check_input, version=()): def register_shape_i_c_code(typ, code, check_input, version=()):
""" Tell Shape_i how to generate C code for a Theano Type """ Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and not
instance of the class. an instance of the class.
:param code: C code that gets the shape of dimensions %(i)s for the Theano type 'typ'. :param code: C code that gets the shape of dimensions %(i)s for the
Theano type 'typ'.
Use %(iname)s and %(oname)s for the input and output C Use %(iname)s and %(oname)s for the input and output C
variable names respectively. variable names respectively.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op): ...@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
return type(self) == type(other) and self.axis == other.axis return type(self) == type(other) and self.axis == other.axis
def __hash__(self): def __hash__(self):
items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique # no ambiguity because each item key is unique
items = sorted(self.axis.iteritems())
return hash((type(self), tuple(items))) return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op): ...@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())): if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
raise ValueError('Trying to rebroadcast non-existent dimension') raise ValueError('Trying to rebroadcast non-existent dimension')
t = x.type.clone(broadcastable=[self.axis.get(i, b) t = x.type.clone(
for i, b in enumerate( broadcastable=[self.axis.get(i, b)
x.type.broadcastable)]) for i, b in enumerate(x.type.broadcastable)])
return gof.Apply(self, [x], [t()]) return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op): ...@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
for t, (c, v) in sorted(self.c_code_and_version.items(), for t, (c, v) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for Rebroadcast, but it has " warnings.warn("Type %s has C code for Rebroadcast, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_rebroadcast_c_code." % t, "keyword arg when calling "
"register_rebroadcast_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(), ...@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply=None): c_support_code_apply=None):
""" Tell SpecifyShape how to generate C code for a Theano Type """ Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an :param typ: A Theano type. It must be the Theano class itself and
instance of the class. not an instance of the class.
:param code: C code that checks the shape and returns a view for the Theano type 'typ'. :param code: C code that checks the shape and returns a view for
Use %(iname)s and %(oname)s for the input and output C the Theano type 'typ'. Use %(iname)s and %(oname)s
variable names respectively. for the input and output C variable names
%(shape)s is the vector of shape of %(iname)s. respectively. %(shape)s is the vector of shape of
Check that its length is good. %(iname)s. Check that its length is good.
:param version: A number indicating the version of the code, for cache. :param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code. :param c_support_code_apply: extra code.
""" """
SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) SpecifyShape.c_code_and_version[typ] = (code, version,
c_support_code_apply)
class SpecifyShape(gof.Op): class SpecifyShape(gof.Op):
...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op): ...@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
new_shape = [] new_shape = []
for dim in xrange(node.inputs[0].ndim): for dim in xrange(node.inputs[0].ndim):
try: try:
s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim]) s = theano.tensor.get_scalar_constant_value(
node.inputs[1][dim])
s = theano.tensor.as_tensor_variable(s) s = theano.tensor.as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op): ...@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
code, version, _ = self.c_code_and_version[itype] code, version, _ = self.c_code_and_version[itype]
return code % locals() return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub) return super(SpecifyShape, self).c_code(node, node, inames,
onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op): ...@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
for t, (c, v, _) in sorted(self.c_code_and_version.items(), for t, (c, v, _) in sorted(self.c_code_and_version.items(),
key=lambda pair: str(pair[0])): key=lambda pair: str(pair[0])):
if not v: if not v:
warnings.warn("Type %s has C code for SpecifyShape, but it has " warnings.warn("Type %s has C code for SpecifyShape, but it "
"no version. You should add a 'version' keyword arg " "has no version. You should add a 'version' "
"when calling register_specify_shape_c_code." % t, "keyword arg when calling "
"register_specify_shape_c_code." % t,
stacklevel=2) stacklevel=2)
return () return ()
version.append((str(t), v)) version.append((str(t), v))
......
"""Provide a simple user friendly API """ """Provide a simple user friendly API """
__docformat__ = 'restructuredtext en'
from theano import config from theano import config
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile import UnusedInputError from theano.compile import UnusedInputError
...@@ -13,6 +9,8 @@ from theano.gof import Variable, Constant ...@@ -13,6 +9,8 @@ from theano.gof import Variable, Constant
import logging import logging
_logger = logging.getLogger("theano.compile.pfunc") _logger = logging.getLogger("theano.compile.pfunc")
__docformat__ = 'restructuredtext en'
def rebuild_collect_shared(outputs, def rebuild_collect_shared(outputs,
inputs=None, inputs=None,
...@@ -199,7 +197,7 @@ def rebuild_collect_shared(outputs, ...@@ -199,7 +197,7 @@ def rebuild_collect_shared(outputs,
# filter_variable ensure smooth conversion of cpu/gpu Types # filter_variable ensure smooth conversion of cpu/gpu Types
try: try:
update_val = store_into.type.filter_variable(update_val) update_val = store_into.type.filter_variable(update_val)
except TypeError as e: except TypeError:
err_msg = ('An update must have the same type as the' err_msg = ('An update must have the same type as the'
' original shared variable (shared_var=%s,' ' original shared variable (shared_var=%s,'
' shared_var.type=%s,' ' shared_var.type=%s,'
...@@ -232,8 +230,8 @@ def rebuild_collect_shared(outputs, ...@@ -232,8 +230,8 @@ def rebuild_collect_shared(outputs,
cloned_outputs.append(Out(cloned_v, borrow=v.borrow)) cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
else: else:
raise TypeError('Outputs must be theano Variable or ' raise TypeError('Outputs must be theano Variable or '
'Out instances. Received ' + str(v) 'Out instances. Received ' + str(v) +
+ ' of type ' + str(type(v))) ' of type ' + str(type(v)))
# computed_list.append(cloned_v) # computed_list.append(cloned_v)
else: else:
if isinstance(outputs, Variable): if isinstance(outputs, Variable):
...@@ -275,35 +273,38 @@ def rebuild_collect_shared(outputs, ...@@ -275,35 +273,38 @@ def rebuild_collect_shared(outputs,
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow=None): strict=False, allow_downcast=None, implicit=None,
borrow=None):
""" """
:param variable: A variable in an expression graph to use as a :param variable: A variable in an expression graph to use as a
compiled-function parameter compiled-function parameter
:param default: The default value to use at call-time (can also be a Container where :param default: The default value to use at call-time (can
the function will find a value at call-time.) also be a Container where the function will find a value
at call-time.)
:param name: A string to identify this parameter from function kwargs. :param name: A string to identify this parameter from function kwargs.
:param mutable: True -> function is allowed to modify this argument. :param mutable: True -> function is allowed to modify this argument.
:param borrow: Whether the function is allowed to alias some output to :param borrow: Whether the function is allowed to alias some
this input. Using None (default) means we re-use the same value as the output to this input. Using None (default) means we re-use
`mutable` flag. the same value as the `mutable` flag.
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
:param strict: False -> function arguments may be copied or cast to match the :param strict: False -> function arguments may be copied or
type required by the parameter `variable`. cast to match the type required by the parameter
`variable`.
True -> function arguments must exactly match the type True -> function arguments must exactly match the type
required by `variable`. required by `variable`.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment. True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss. False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX. None -> only allow downcasting of a Python float to a scalar floatX.
:param implicit: see help(theano.io.In) :param implicit: see help(theano.io.In)
""" """
self.variable = variable self.variable = variable
self.default = default self.default = default
...@@ -335,7 +336,7 @@ class Param(object): ...@@ -335,7 +336,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=None, givens=None, def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=None, rebuild_strict=True, allow_input_downcast=None,
profile=None, on_unused_input=None,output_keys=None): profile=None, on_unused_input=None, output_keys=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -348,30 +349,35 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -348,30 +349,35 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `theano.compile.Mode` instance. :type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode :param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict. :type updates: iterable over pairs (shared_variable,
:param updates: update the values for SharedVariable inputs according to these expressions new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1 :type givens: iterable over pairs (Var1, Var2) of Variables. List,
and Var2 in each pair must have the same Type. tuple or dict. The Var1 and Var2 in each pair must have the
same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces :param givens: specific substitutions to make in the computation
Var1). graph (Var2 replaces Var1).
:type no_default_updates: either bool or list of Variables :type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables. :param no_default_updates: if True, do not perform any automatic
If False (default), perform them all. Else, perform automatic updates on all Variables update on Variables. If False (default), perform them
that are neither in "updates" nor in "no_default_updates". all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:type name: None or string :type name: None or string
:param name: attaches a name to the profiling result of this function. :param name: attaches a name to the profiling result of this function.
:type allow_input_downcast: Boolean :type allow_input_downcast: Boolean
:param allow_input_downcast: True means that the values passed as :param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to
the dtype of the corresponding Variable, which may lose precision. fit the dtype of the corresponding Variable, which may lose
False means that it will only be cast to a more general, or precision. False means that it will only be cast to a more
precise, type. None (default) is almost like False, but allows general, or precise, type. None (default) is almost like
downcasting of Python float scalars to floatX. False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, str, or ProfileStats instance :type profile: None, True, str, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats :param profile: accumulate profiling information into a given ProfileStats
...@@ -389,30 +395,32 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -389,30 +395,32 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
:rtype: theano.compile.Function :rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given
and update the implicit function arguments according to the `updates`. the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
:note: Regarding givens: Be careful to make sure that these
substitutions are independent--behaviour when Var1 of one pair
appears in the graph leading to Var2 in another expression is
undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
""" """
# #
# This function works by cloning the graph (except for the inputs), and then shipping it # This function works by cloning the graph (except for the
# off to compile.function # inputs), and then shipping it off to compile.function (There it
# (There it will be cloned again, unnecessarily, because it doesn't know that we already # will be cloned again, unnecessarily, because it doesn't know
# cloned it.) # that we already cloned it.)
# #
# First, it clones the replacements named in the givens argument, and points each Var1 to # First, it clones the replacements named in the givens argument,
# the clone of Var2. # and points each Var1 to the clone of Var2. Then it sets the
# Then it sets the inputs in the clone dictionary. # inputs in the clone dictionary. After these steps, we are
# After these steps, we are assuming that the clone dictionary contains all the inputs to # assuming that the clone dictionary contains all the inputs to
# the computation graph. # the computation graph.
# #
# Then it clones the outputs and the update expressions. This rebuilds a computation graph # Then it clones the outputs and the update expressions. This
# from the inputs and the givens. # rebuilds a computation graph from the inputs and the givens.
# #
if updates is None: if updates is None:
updates = [] updates = []
...@@ -431,11 +439,13 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -431,11 +439,13 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
# useful. # useful.
if not isinstance(params, (list, tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or a tuple") raise Exception("in pfunc() the first argument must be a list or "
"a tuple")
if not isinstance(no_default_updates, bool)\ if not isinstance(no_default_updates, bool)\
and not isinstance(no_default_updates, list): and not isinstance(no_default_updates, list):
raise TypeError("no_default_update should be either a boolean or a list") raise TypeError("no_default_update should be either a boolean or "
"a list")
if len(updates) > 0 and any(isinstance(v, Variable) if len(updates) > 0 and any(isinstance(v, Variable)
for v in iter_over_pairs(updates)): for v in iter_over_pairs(updates)):
...@@ -494,9 +504,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -494,9 +504,10 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
i.variable = iv i.variable = iv
for sv in shared_inputs: for sv in shared_inputs:
# pass value of None here # pass value of None
# value will be stored in the resulting functions' defaults list # value will be stored in the resulting functions' defaults
# but since the value of shared variables never needs to be refed, it is not needed # list but since the value of shared variables never needs to
# be refed, it is not needed
if sv in update_d: if sv in update_d:
si = In(variable=sv, value=sv.container, mutable=True, si = In(variable=sv, value=sv.container, mutable=True,
borrow=True, update=update_d[sv], shared=True) borrow=True, update=update_d[sv], shared=True)
...@@ -506,8 +517,9 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None, ...@@ -506,8 +517,9 @@ def pfunc(params, outputs=None, mode=None, updates=None, givens=None,
inputs.append(si) inputs.append(si)
return orig_function(inputs, cloned_outputs, mode, return orig_function(inputs, cloned_outputs, mode,
accept_inplace=accept_inplace, name=name, profile=profile, accept_inplace=accept_inplace, name=name,
on_unused_input=on_unused_input, output_keys=output_keys) profile=profile, on_unused_input=on_unused_input,
output_keys=output_keys)
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
...@@ -10,15 +10,15 @@ from theano.gof.link import WrapLinker ...@@ -10,15 +10,15 @@ from theano.gof.link import WrapLinker
from theano.compile.mode import (Mode, register_mode, from theano.compile.mode import (Mode, register_mode,
predefined_modes, predefined_linkers, predefined_modes, predefined_linkers,
predefined_optimizers) predefined_optimizers)
from theano import gof
from theano.configparser import config, AddConfigVar, IntParam, BoolParam from theano.configparser import config, AddConfigVar, IntParam, BoolParam
from theano.compile.function_module import FunctionMaker from theano.compile.function_module import FunctionMaker
run_cthunk = None # Will be imported only when needed.
from profiling import ProfileStats from profiling import ProfileStats
run_cthunk = None # Will be imported only when needed.
import_time = time.time() import_time = time.time()
AddConfigVar('ProfileMode.n_apply_to_print', AddConfigVar('ProfileMode.n_apply_to_print',
"Number of apply instances to print by default", "Number of apply instances to print by default",
IntParam(15, lambda i: i > 0), IntParam(15, lambda i: i > 0),
...@@ -30,8 +30,8 @@ AddConfigVar('ProfileMode.n_ops_to_print', ...@@ -30,8 +30,8 @@ AddConfigVar('ProfileMode.n_ops_to_print',
in_c_key=False) in_c_key=False)
AddConfigVar('ProfileMode.min_memory_size', AddConfigVar('ProfileMode.min_memory_size',
"""For the memory profile, do not print apply nodes if the size "For the memory profile, do not print apply nodes if the size "
of their outputs (in bytes) is lower then this threshold""", "of their outputs (in bytes) is lower then this threshold",
IntParam(1024, lambda i: i >= 0), IntParam(1024, lambda i: i >= 0),
in_c_key=False) in_c_key=False)
...@@ -84,7 +84,8 @@ class Profile_Maker(FunctionMaker): ...@@ -84,7 +84,8 @@ class Profile_Maker(FunctionMaker):
def new_fn(): def new_fn():
self.mode.apply_time = self.mode.profile_stats[ret].apply_time self.mode.apply_time = self.mode.profile_stats[ret].apply_time
self.mode.variable_shape = self.mode.profile_stats[ret].variable_shape self.mode.variable_shape = \
self.mode.profile_stats[ret].variable_shape
ret_fn() ret_fn()
# delete the old apply_time variable # delete the old apply_time variable
# because it doesn't mean the same thing anymore. # because it doesn't mean the same thing anymore.
...@@ -97,12 +98,12 @@ class Profile_Maker(FunctionMaker): ...@@ -97,12 +98,12 @@ class Profile_Maker(FunctionMaker):
global run_cthunk global run_cthunk
if run_cthunk is None and any(profile.apply_cimpl.values()): if run_cthunk is None and any(profile.apply_cimpl.values()):
# Lazy import to avoid compilation when importing theano. # Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk from theano.gof.cutils import run_cthunk # noqa
warnings.warn( warnings.warn(
"DEPRECATION WARNING: The ProfileMode is deprecated. Use the Theano" "DEPRECATION WARNING: The ProfileMode is deprecated. "
" flags/parameter to theano.function 'profile=True' instead" "Use the Theano flags/parameter to theano.function "
" of 'mode=ProfileMode'") "'profile=True' instead of 'mode=ProfileMode'")
return ret return ret
...@@ -209,17 +210,34 @@ class ProfileMode(Mode): ...@@ -209,17 +210,34 @@ class ProfileMode(Mode):
self.fn_time = 0 self.fn_time = 0
def print_summary(self, **kwargs): def print_summary(self, **kwargs):
""" Print 3 summary that show where the time is spend. The first show an Apply-wise summary, the second show an Op-wise summary, the third show an type-Op-wise summary. """ Print 3 summaries that show where time is spent. The first shows
an Apply-wise summary, the second an Op-wise summary and the
The Apply-wise summary print the timing information for the worst offending Apply nodes. This corresponds to individual Op applications within your graph which take the longest to execute (so if you use dot twice, you will see two entries there). third a type-Op-wise summary.
The Op-wise summary print the execution time of all Apply nodes executing the same Op are grouped together and the total execution time per Op is shown (so if you use dot twice, you will see only one entry there corresponding to the sum of the time spent in each of them). If two Op have different hash value, they will be separate.
The type-Op-wise summary group the result by type of op. So event if two Op have different hash value, they will be merged. The Apply-wise summary prints the timing information for the
worst offending Apply nodes. This corresponds to individual Op
Their is an hack with the Op-wise summary. Go see it if you want to know more. applications within your graph which take the longest to
execute (so if you use dot twice, you will see two entries
there).
The Op-wise summary prints the execution time of all Apply
nodes executing the same Op grouped together and the total
execution time per Op is shown (so if you use dot twice, you
will see only one entry there corresponding to the sum of the
time spent in each of them). If two Ops have different hash
value, they will be separate.
The type-Op-wise summary group the result by type of op. So
event if two Op have different hash value, they will be
merged.
There is an hack with the Op-wise summary. Go see it if you
want to know more.
:param kwargs: They are passed to print_summary_ expanded. :param kwargs: They are passed to print_summary_ expanded.
Currently there is n_apply_to_print, n_ops_to_print and min_memory_size Currently there is n_apply_to_print,
that are accepted. n_ops_to_print and min_memory_size that are
accepted.
""" """
compile_time = sum([ps.compile_time for ps compile_time = sum([ps.compile_time for ps
in self.profile_stats.values()]) in self.profile_stats.values()])
...@@ -261,14 +279,18 @@ class ProfileMode(Mode): ...@@ -261,14 +279,18 @@ class ProfileMode(Mode):
**kwargs) **kwargs)
def print_diff_summary(self, other, **kwargs): def print_diff_summary(self, other, **kwargs):
""" As print_summary, but print the difference on two different profile mode. """ As print_summary, but print the difference on two different
TODO: Also we don't print the Apply-wise summary as it don't work for now. profile mode.
TODO: Also we don't print the Apply-wise summary as it don't
work for now.
TODO: make comparaison with gpu code. TODO: make comparaison with gpu code.
:param other: the other instance of ProfileMode that we want to be compared to. :param other: the other instance of ProfileMode that we want
to be compared to.
:param kwargs: They are passed to print_summary_ expanded. :param kwargs: They are passed to print_summary_ expanded.
Currently there is n_apply_to_print, n_ops_to_print and min_memory_size Currently there is n_apply_to_print, n_ops_to_print and
that are accepted. min_memory_size that are accepted.
""" """
def diff_dict(a_time, b_time_): def diff_dict(a_time, b_time_):
...@@ -331,7 +353,8 @@ class ProfileMode(Mode): ...@@ -331,7 +353,8 @@ class ProfileMode(Mode):
print("ProfileMode is deprecated! Use the new profiler.") print("ProfileMode is deprecated! Use the new profiler.")
print(" The Theano flags to enable it ise: profile=True") print(" The Theano flags to enable it ise: profile=True")
print(" The Theano flags for the memory profile to it is: profile_memory=True") print(" The Theano flags for the memory profile to it is: "
"profile_memory=True")
total_time = time.time() - import_time total_time = time.time() - import_time
total_fct_time = sum(fct_call_time.values()) total_fct_time = sum(fct_call_time.values())
...@@ -352,25 +375,37 @@ class ProfileMode(Mode): ...@@ -352,25 +375,37 @@ class ProfileMode(Mode):
print('ProfileMode.%s(%s)' % (fct_name, message)) print('ProfileMode.%s(%s)' % (fct_name, message))
print('---------------------------') print('---------------------------')
print() print()
print('Time since import %.3fs'%(total_time)) print('Time since import %.3fs' % (total_time))
print('Theano compile time: %.3fs (%.1f%% since import)'%(compile_time, compile_time/total_time*100)) print('Theano compile time: %.3fs (%.1f%% since import)' %
print(' Optimization time: %.3fs'%(other_time['optimizer_time'])) (compile_time, compile_time/total_time*100))
print(' Linker time: %.3fs'%(other_time['linker_time'])) print(' Optimization time: %.3fs' % (other_time['optimizer_time']))
print('Theano fct call %.3fs (%.1f%% since import)'%(total_fct_time, total_fct_time/total_time*100)) print(' Linker time: %.3fs' % (other_time['linker_time']))
print(' Theano Op time %.3fs %.1f%%(since import) %.1f%%(of fct call)' % ( print('Theano fct call %.3fs (%.1f%% since import)' %
local_time, local_time/total_time*100, time_pr_in_fct)) (total_fct_time, total_fct_time/total_time*100))
print(' Theano function overhead in ProfileMode %.3fs %.1f%%(since import) %.1f%%(of fct call)' % ( print(' Theano Op time %.3fs %.1f%%(since import) %.1f%%'
overhead_time, overhead_time/total_time*100, overhead_time_pourcent_fct_time)) '(of fct call)' % (local_time, local_time/total_time*100,
print('%i Theano fct call, %.3fs per call'%(total_fct_call, time_per_call)) time_pr_in_fct))
print('Rest of the time since import %.3fs %.1f%%'%(unknown_time, unknown_time/total_time*100)) print(' Theano function overhead in ProfileMode %.3fs %.1f%%'
'(since import) %.1f%%(of fct call)' % (
overhead_time, overhead_time/total_time*100,
overhead_time_pourcent_fct_time))
print('%i Theano fct call, %.3fs per call' %
(total_fct_call, time_per_call))
print('Rest of the time since import %.3fs %.1f%%' %
(unknown_time, unknown_time/total_time*100))
print() print()
print('Theano fct summary:') print('Theano fct summary:')
print('<% total fct time> <total time> <time per call> <nb call> <fct name>') print('<% total fct time> <total time> <time per call> <nb call> '
'<fct name>')
for key in fct_call.keys(): for key in fct_call.keys():
if fct_call[key] > 0: if fct_call[key] > 0:
print(' %4.1f%% %.3fs %.2es %d %s'%(fct_call_time[key]/total_fct_time*100 , fct_call_time[key], print(' %4.1f%% %.3fs %.2es %d %s' %
fct_call_time[key]/fct_call[key], fct_call[key], key.name)) (fct_call_time[key]/total_fct_time*100,
fct_call_time[key],
fct_call_time[key]/fct_call[key],
fct_call[key],
key.name))
else: else:
print(' NOT CALLED', key.name) print(' NOT CALLED', key.name)
...@@ -387,7 +422,8 @@ class ProfileMode(Mode): ...@@ -387,7 +422,8 @@ class ProfileMode(Mode):
op_apply.setdefault(op, 0) op_apply.setdefault(op, 0)
sop_apply.setdefault(type(a.op), 0) sop_apply.setdefault(type(a.op), 0)
op_time[op] += t op_time[op] += t
nb_call = [v for k, v in fct_call.items() if k.maker.fgraph is a.fgraph][0] nb_call = [v for k, v in fct_call.items()
if k.maker.fgraph is a.fgraph][0]
op_cimpl.setdefault(a.op, True) op_cimpl.setdefault(a.op, True)
op_cimpl[a.op] = op_cimpl[a.op] and apply_cimpl.get(a, False) op_cimpl[a.op] = op_cimpl[a.op] and apply_cimpl.get(a, False)
if t == 0: if t == 0:
...@@ -401,7 +437,8 @@ class ProfileMode(Mode): ...@@ -401,7 +437,8 @@ class ProfileMode(Mode):
sop_time = {} sop_time = {}
sop_call = {} sop_call = {}
sop_op = {} sop_op = {}
sop_cimpl = {} # map each op class to Bool. True iff all applies were done in c. # map each op class to Bool. True iff all applies were done in c.
sop_cimpl = {}
for a, t in op_time.items(): for a, t in op_time.items():
typ = type(a) typ = type(a)
sop_time.setdefault(typ, 0) sop_time.setdefault(typ, 0)
...@@ -415,8 +452,11 @@ class ProfileMode(Mode): ...@@ -415,8 +452,11 @@ class ProfileMode(Mode):
# Print the summary per op class. # Print the summary per op class.
print() print()
print('Single Op-wise summary:') print('Single Op-wise summary:')
print('<% of local_time spent on this kind of Op> <cumulative %> <self seconds> <cumulative seconds> <time per call> [*] <nb_call> <nb_op> <nb_apply> <Op name>') print('<% of local_time spent on this kind of Op> <cumulative %> '
sotimes = [(t*100/local_time, t, a, sop_cimpl[a], sop_call[a], sop_op[a], sop_apply[a]) for a, t in sop_time.items()] '<self seconds> <cumulative seconds> <time per call> [*] '
'<nb_call> <nb_op> <nb_apply> <Op name>')
sotimes = [(t*100/local_time, t, a, sop_cimpl[a], sop_call[a],
sop_op[a], sop_apply[a]) for a, t in sop_time.items()]
sotimes.sort() sotimes.sort()
sotimes.reverse() sotimes.reverse()
tot = 0 tot = 0
...@@ -430,9 +470,12 @@ class ProfileMode(Mode): ...@@ -430,9 +470,12 @@ class ProfileMode(Mode):
msg = '*' msg = '*'
else: else:
msg = ' ' msg = ' '
print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %5d %2d %2d %s' % (f, ftot, t, tot, t/nb_call, msg, nb_call, nb_op, nb_apply, a)) print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %5d %2d '
print(' ... (remaining %i single Op account for %.2f%%(%.2fs) of the runtime)'\ '%2d %s' % (f, ftot, t, tot, t/nb_call, msg, nb_call,
% (max(0, len(sotimes)-n_ops_to_print), nb_op, nb_apply, a))
print(' ... (remaining %i single Op account for %.2f%%(%.2fs) of '
'the runtime)' %
(max(0, len(sotimes)-n_ops_to_print),
sum(soinfo[0] for soinfo in sotimes[n_ops_to_print:]), sum(soinfo[0] for soinfo in sotimes[n_ops_to_print:]),
sum(soinfo[1] for soinfo in sotimes[n_ops_to_print:]))) sum(soinfo[1] for soinfo in sotimes[n_ops_to_print:])))
...@@ -446,12 +489,18 @@ class ProfileMode(Mode): ...@@ -446,12 +489,18 @@ class ProfileMode(Mode):
flops_msg = '' flops_msg = ''
if op_flops: if op_flops:
flops_msg = ' <MFlops/s>' flops_msg = ' <MFlops/s>'
print('\nHACK WARNING: we print the flops for some OP, but the logic don\'t always work. You need to know the internal of Theano to make it work correctly. Otherwise don\'t use!') print("\nHACK WARNING: we print the flops for some OP, but the "
"logic doesn't always work. You need to know the "
"internals of Theano to make it work correctly. "
"Otherwise don't use it!")
print() print()
print('Op-wise summary:') print('Op-wise summary:')
print('<%% of local_time spent on this kind of Op> <cumulative %%> <self seconds> <cumulative seconds> <time per call> [*] %s <nb_call> <nb apply> <Op name>'%(flops_msg)) print('<%% of local_time spent on this kind of Op> <cumulative %%> '
'<self seconds> <cumulative seconds> <time per call> [*] %s '
'<nb_call> <nb apply> <Op name>' % (flops_msg))
otimes = [(t*100/local_time, t, a, op_cimpl.get(a, 0), op_call.get(a, 0), op_apply.get(a, 0)) otimes = [(t*100/local_time, t, a, op_cimpl.get(a, 0),
op_call.get(a, 0), op_apply.get(a, 0))
for a, t in op_time.items()] for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
...@@ -467,20 +516,33 @@ class ProfileMode(Mode): ...@@ -467,20 +516,33 @@ class ProfileMode(Mode):
else: else:
msg = ' ' msg = ' '
if op_flops: if op_flops:
print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %7.1f %5d %2d %s' % (f, ftot, t, tot, t/nb_call, msg, op_flops.get(a, -1), nb_call, nb_apply, a)) print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %7.1f '
'%5d %2d %s' % (f, ftot, t, tot, t/nb_call, msg,
op_flops.get(a, -1), nb_call, nb_apply,
a))
else: else:
print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %5d %2d %s' % (f, ftot, t, tot, t/nb_call, msg, nb_call, nb_apply, a)) print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %5d %2d '
print(' ... (remaining %i Op account for %6.2f%%(%.2fs) of the runtime)'\ '%s' % (f, ftot, t, tot, t/nb_call, msg, nb_call,
% (max(0, len(otimes)-n_ops_to_print), nb_apply, a))
sum(f for f, t, a, ci, nb_call, nb_op in otimes[n_ops_to_print:]), print(' ... (remaining %i Op account for %6.2f%%(%.2fs) of the '
sum(t for f, t, a, ci, nb_call, nb_op in otimes[n_ops_to_print:]))) 'runtime)' %
(max(0, len(otimes)-n_ops_to_print),
sum(f for f, t, a, ci, nb_call, nb_op in
otimes[n_ops_to_print:]),
sum(t for f, t, a, ci, nb_call, nb_op in
otimes[n_ops_to_print:])))
print('(*) Op is running a c implementation') print('(*) Op is running a c implementation')
if print_apply: if print_apply:
print() print()
print('Apply-wise summary:') print('Apply-wise summary:')
print('<% of local_time spent at this position> <cumulative %%> <apply time> <cumulative seconds> <time per call> [*] <nb_call> <Apply position> <Apply Op name>') print('<% of local_time spent at this position> <cumulative %%> '
atimes = [(t*100/local_time, t, a, [v for k, v in fct_call.items() if k.maker.fgraph is a[1].fgraph][0]) for a, t in apply_time.items()] '<apply time> <cumulative seconds> <time per call> [*] '
'<nb_call> <Apply position> <Apply Op name>')
atimes = [(t*100/local_time, t, a,
[v for k, v in fct_call.items()
if k.maker.fgraph is a[1].fgraph][0])
for a, t in apply_time.items()]
atimes.sort() atimes.sort()
atimes.reverse() atimes.reverse()
tot = 0 tot = 0
...@@ -493,10 +555,13 @@ class ProfileMode(Mode): ...@@ -493,10 +555,13 @@ class ProfileMode(Mode):
msg = '*' msg = '*'
else: else:
msg = ' ' msg = ' '
print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %i %2i %s' % ( print(' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %s %i '
f, ftot, t, tot, t/nb_call, msg, nb_call, a[0], str(a[1]))) '%2i %s' %
print(' ... (remaining %i Apply instances account for %.2f%%(%.2fs) of the runtime)'\ (f, ftot, t, tot, t/nb_call, msg, nb_call, a[0],
% (max(0, len(atimes)-n_apply_to_print), str(a[1])))
print(' ... (remaining %i Apply instances account for '
'%.2f%%(%.2fs) of the runtime)' %
(max(0, len(atimes)-n_apply_to_print),
sum(f for f, t, a, nb_call in atimes[n_apply_to_print:]), sum(f for f, t, a, nb_call in atimes[n_apply_to_print:]),
sum(t for f, t, a, nb_call in atimes[n_apply_to_print:]))) sum(t for f, t, a, nb_call in atimes[n_apply_to_print:])))
print('(*) Op is running a c implementation') print('(*) Op is running a c implementation')
...@@ -506,8 +571,9 @@ class ProfileMode(Mode): ...@@ -506,8 +571,9 @@ class ProfileMode(Mode):
other_time) other_time)
if not variable_shape: if not variable_shape:
print("""\nProfile of Theano intermediate memory disabled. print("\nProfile of Theano intermediate memory disabled. "
To enabled, put the Theano flag ProfileMode.profile_memory to True.""") "To enable, set the Theano flag ProfileMode.profile_memory "
"to True.""")
else: else:
print(""" print("""
The memory profile in ProfileMode is removed! The memory profile in ProfileMode is removed!
...@@ -540,7 +606,6 @@ Test them first, as they are not guaranteed to always provide a speedup.""") ...@@ -540,7 +606,6 @@ Test them first, as they are not guaranteed to always provide a speedup.""")
scal.Cosh, scal.Sinh, scal.Cosh, scal.Sinh,
T.nnet.sigm.ScalarSigmoid, T.nnet.sigm.ScalarSigmoid,
T.nnet.sigm.ScalarSoftplus] T.nnet.sigm.ScalarSoftplus]
# Abs, Mod in float{32,64} only
def get_scalar_ops(s): def get_scalar_ops(s):
if isinstance(s, theano.scalar.Composite): if isinstance(s, theano.scalar.Composite):
...@@ -566,7 +631,8 @@ Test them first, as they are not guaranteed to always provide a speedup.""") ...@@ -566,7 +631,8 @@ Test them first, as they are not guaranteed to always provide a speedup.""")
if s_op.__class__ in scalar_op_amdlibm_speed_up: if s_op.__class__ in scalar_op_amdlibm_speed_up:
return True return True
elif s_op.__class__ not in scalar_op_amdlibm_no_speed_up: elif s_op.__class__ not in scalar_op_amdlibm_no_speed_up:
print("We don't know if amdlibm will accelerate this scalar op.", s_op) print("We don't know if amdlibm will accelerate "
"this scalar op.", s_op)
return False return False
def exp_float32_op(op): def exp_float32_op(op):
...@@ -585,7 +651,9 @@ Test them first, as they are not guaranteed to always provide a speedup.""") ...@@ -585,7 +651,9 @@ Test them first, as they are not guaranteed to always provide a speedup.""")
# tip 2 # tip 2
if not config.lib.amdlibm and any([amdlibm_speed_up(a.op) for i, a if not config.lib.amdlibm and any([amdlibm_speed_up(a.op) for i, a
in apply_time]): in apply_time]):
print(" - Try installing amdlibm and set the Theano flag lib.amdlibm=True. This speeds up only some Elemwise operation.") print(" - Try installing amdlibm and set the Theano flag "
"lib.amdlibm=True. This speeds up only some Elemwise "
"operation.")
printed_tip = True printed_tip = True
# tip 3 # tip 3
...@@ -601,7 +669,8 @@ Test them first, as they are not guaranteed to always provide a speedup.""") ...@@ -601,7 +669,8 @@ Test them first, as they are not guaranteed to always provide a speedup.""")
for a, t in apply_time.iteritems(): for a, t in apply_time.iteritems():
node = a[1] node = a[1]
if (isinstance(node.op, T.Dot) and if (isinstance(node.op, T.Dot) and
all([len(i.type.broadcastable) == 2 for i in node.inputs])): all([len(i.type.broadcastable) == 2
for i in node.inputs])):
print((" - You have a dot operation that was not optimized to" print((" - You have a dot operation that was not optimized to"
" dot22 (which is faster). Make sure the inputs are " " dot22 (which is faster). Make sure the inputs are "
"float32 or float64, and are the same for both inputs. " "float32 or float64, and are the same for both inputs. "
......
...@@ -240,7 +240,6 @@ class ProfileStats(object): ...@@ -240,7 +240,6 @@ class ProfileStats(object):
else: else:
self.flag_time_thunks = flag_time_thunks self.flag_time_thunks = flag_time_thunks
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
#print >> sys.stderr, "self.message", self.message
if atexit_print: if atexit_print:
global _atexit_print_list global _atexit_print_list
_atexit_print_list.append(self) _atexit_print_list.append(self)
...@@ -377,9 +376,6 @@ class ProfileStats(object): ...@@ -377,9 +376,6 @@ class ProfileStats(object):
tot = 0 tot = 0
print('Class', file=file) print('Class', file=file)
print('---', file=file) print('---', file=file)
#print >> file, '<% time> <cumulative %%> <apply time>,'
#print >>file, '<cumulative seconds> <time per call> <nb_call>'
#print >>file, '<Class name>'
hs = [] hs = []
# formatting string # formatting string
es = [] es = []
...@@ -421,7 +417,8 @@ class ProfileStats(object): ...@@ -421,7 +417,8 @@ class ProfileStats(object):
tot += t tot += t
ftot = tot * 100 / local_time ftot = tot * 100 / local_time
# Remove the useless start and end of the class name: # Remove the useless start and end of the class name:
# "<class 'theano.sandbox.cuda.blas.GpuDot22'>" -> "theano.sandbox.cuda.blas.GpuDot22" # "<class 'theano.sandbox.cuda.blas.GpuDot22'>" ->
# "theano.sandbox.cuda.blas.GpuDot22"
class_name = str(a)[8:-2][:maxlen] class_name = str(a)[8:-2][:maxlen]
print(format_str % (f, ftot, t, t / nb_call, print(format_str % (f, ftot, t, t / nb_call,
impl, nb_call, impl, nb_call,
...@@ -429,10 +426,12 @@ class ProfileStats(object): ...@@ -429,10 +426,12 @@ class ProfileStats(object):
# While this carries over less information, it is arranged such # While this carries over less information, it is arranged such
# that it way more readeable that the previous output of the # that it way more readeable that the previous output of the
# profiler # profiler
print(' ... (remaining %i Classes account for %6.2f%%(%.2fs) of the runtime)'\ print(' ... (remaining %i Classes account for %6.2f%%(%.2fs) of '
% (max(0, len(otimes) - N), 'the runtime)' %
(max(0, len(otimes) - N),
sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]), sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])), file=file) sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])),
file=file)
print('', file=file) print('', file=file)
def summary_ops(self, file=sys.stderr, N=None): def summary_ops(self, file=sys.stderr, N=None):
...@@ -459,9 +458,6 @@ class ProfileStats(object): ...@@ -459,9 +458,6 @@ class ProfileStats(object):
tot = 0 tot = 0
print('Ops', file=file) print('Ops', file=file)
print('---', file=file) print('---', file=file)
#print >> file, '<% time> <cumulative %%> <apply time>,'
#print >>file, '<cumulative seconds> <time per call> <nb_call>'
#print >>file, '<Op name>'
hs = [] hs = []
# formatting string # formatting string
es = [] es = []
...@@ -508,10 +504,12 @@ class ProfileStats(object): ...@@ -508,10 +504,12 @@ class ProfileStats(object):
# While this carries over less information, it is arranged such # While this carries over less information, it is arranged such
# that it way more readeable that the previous output of the # that it way more readeable that the previous output of the
# profiler # profiler
print(' ... (remaining %i Ops account for %6.2f%%(%.2fs) of the runtime)'\ print(' ... (remaining %i Ops account for %6.2f%%(%.2fs) of '
% (max(0, len(otimes) - N), 'the runtime)' %
(max(0, len(otimes) - N),
sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]), sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])), file=file) sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])),
file=file)
print('', file=file) print('', file=file)
def summary_nodes(self, file=sys.stderr, N=None): def summary_nodes(self, file=sys.stderr, N=None):
...@@ -526,7 +524,6 @@ class ProfileStats(object): ...@@ -526,7 +524,6 @@ class ProfileStats(object):
print('Apply', file=file) print('Apply', file=file)
print('------', file=file) print('------', file=file)
#print >> file, '<% time> <cumulative %%> <apply time> <cumulative seconds> <time per call> <nb_call> <Apply Op name>'
# headers # headers
hs = [] hs = []
# formatting string # formatting string
...@@ -620,10 +617,9 @@ class ProfileStats(object): ...@@ -620,10 +617,9 @@ class ProfileStats(object):
idx, dtype, sh, st), file=file) idx, dtype, sh, st), file=file)
# Same as before, this I've sacrificied some information making # Same as before, this I've sacrificied some information making
# the output more readable # the output more readable
# print >> file, ' %4.1f%% %5.1f%% %5.3fs %5.3fs %.2es %i %s'%( print(' ... (remaining %i Apply instances account for '
# f, ftot, t, tot, t/nb_call,nb_call, str(a)) '%.2f%%(%.2fs) of the runtime)' %
print(' ... (remaining %i Apply instances account for %.2f%%(%.2fs) of the runtime)'\ (max(0, len(atimes) - N),
% (max(0, len(atimes) - N),
sum(f for f, t, a, nd_id, nb_call in atimes[N:]), sum(f for f, t, a, nd_id, nb_call in atimes[N:]),
sum(t for f, t, a, nd_id, nb_call in atimes[N:])), file=file) sum(t for f, t, a, nd_id, nb_call in atimes[N:])), file=file)
print('', file=file) print('', file=file)
...@@ -640,15 +636,17 @@ class ProfileStats(object): ...@@ -640,15 +636,17 @@ class ProfileStats(object):
100 * self.vm_call_time / self.fct_call_time), file=file) 100 * self.vm_call_time / self.fct_call_time), file=file)
local_time = sum(self.apply_time.values()) local_time = sum(self.apply_time.values())
if local_time > 0: if local_time > 0:
print(' Time in thunks: %es (%.3f%%)' % ( print(' Time in thunks: %es (%.3f%%)' %
local_time, 100 * local_time / self.fct_call_time), file=file) (local_time, 100 * local_time / self.fct_call_time),
file=file)
print(' Total compile time: %es' % self.compile_time, file=file) print(' Total compile time: %es' % self.compile_time, file=file)
print(' Number of Apply nodes: %d' % self.nb_nodes, file=file) print(' Number of Apply nodes: %d' % self.nb_nodes, file=file)
print(' Theano Optimizer time: %es' % self.optimizer_time, file=file) print(' Theano Optimizer time: %es' % self.optimizer_time,
print(' Theano validate time: %es' % self.validate_time, file=file) file=file)
print((' Theano Linker time (includes C,' print(' Theano validate time: %es' % self.validate_time,
' CUDA code generation/compiling): %es' % file=file)
self.linker_time), file=file) print(' Theano Linker time (includes C, CUDA code '
'generation/compiling): %es' % self.linker_time, file=file)
print(' Import time %es' % self.import_time, file=file) print(' Import time %es' % self.import_time, file=file)
print('', file=file) print('', file=file)
...@@ -656,7 +654,8 @@ class ProfileStats(object): ...@@ -656,7 +654,8 @@ class ProfileStats(object):
assert self.validate_time < self.optimizer_time assert self.validate_time < self.optimizer_time
def summary_globals(self, file): def summary_globals(self, file):
print('Time in all call to theano.grad() %es' % theano.gradient.grad_time, file=file) print('Time in all call to theano.grad() %es' %
theano.gradient.grad_time, file=file)
def summary_memory(self, file, N=None): def summary_memory(self, file, N=None):
fct_memory = {} # fgraph->dict(node->[outputs size]) fct_memory = {} # fgraph->dict(node->[outputs size])
...@@ -742,7 +741,8 @@ class ProfileStats(object): ...@@ -742,7 +741,8 @@ class ProfileStats(object):
# two data structure used to mimic Python gc # two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]} viewed_by = {} # {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref
# count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
...@@ -778,14 +778,16 @@ class ProfileStats(object): ...@@ -778,14 +778,16 @@ class ProfileStats(object):
ins = None ins = None
if dmap and idx2 in dmap: if dmap and idx2 in dmap:
vidx = dmap[idx2] vidx = dmap[idx2]
assert len( assert len(vidx) == 1, ("Here we only support the "
vidx) == 1, "Here we only support the possibility to destroy one input" "possibility to destroy one "
"input")
ins = node.inputs[vidx[0]] ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap: if vmap and idx2 in vmap:
assert ins is None assert ins is None
vidx = vmap[idx2] vidx = vmap[idx2]
assert len( assert len(vidx) == 1, ("Here we only support the "
vidx) == 1, "Here we only support the possibility to view one input" "possibility to view one "
"input")
ins = node.inputs[vidx[0]] ins = node.inputs[vidx[0]]
if ins is not None: if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
...@@ -818,7 +820,8 @@ class ProfileStats(object): ...@@ -818,7 +820,8 @@ class ProfileStats(object):
if (dependencies[ins] and if (dependencies[ins] and
ins not in fgraph.outputs and ins not in fgraph.outputs and
ins.owner and ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])): all([compute_map[v][0]
for v in dependencies[ins]])):
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
running_memory_size[cg] -= var_mem[ins] running_memory_size[cg] -= var_mem[ins]
elif ins in view_of: elif ins in view_of:
...@@ -907,14 +910,16 @@ class ProfileStats(object): ...@@ -907,14 +910,16 @@ class ProfileStats(object):
ins = None ins = None
if dmap and idx in dmap: if dmap and idx in dmap:
vidx = dmap[idx] vidx = dmap[idx]
assert len( assert len(vidx) == 1, ("Here we only support "
vidx) == 1, "Here we only support the possibility to destroy one input" "the possibility to "
"destroy one input")
ins = node.inputs[vidx[0]] ins = node.inputs[vidx[0]]
if vmap and idx in vmap: if vmap and idx in vmap:
assert ins is None assert ins is None
vidx = vmap[idx] vidx = vmap[idx]
assert len( assert len(vidx) == 1, ("Here we only support "
vidx) == 1, "Here we only support the possibility to destroy one input" "the possibility to "
"view one input")
ins = node.inputs[vidx[0]] ins = node.inputs[vidx[0]]
if ins is not None: if ins is not None:
# This is needed for destroy_map in case it # This is needed for destroy_map in case it
...@@ -922,7 +927,7 @@ class ProfileStats(object): ...@@ -922,7 +927,7 @@ class ProfileStats(object):
# the output could be different then the # the output could be different then the
# input. # input.
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
# We keep trac of view only again the original # We keep track of view only again the original
origin = view_of.get(ins, ins) origin = view_of.get(ins, ins)
view_of[out] = origin view_of[out] = origin
viewof_change.append(out) viewof_change.append(out)
...@@ -944,8 +949,10 @@ class ProfileStats(object): ...@@ -944,8 +949,10 @@ class ProfileStats(object):
if (dependencies[ins] and if (dependencies[ins] and
ins not in fgraph.outputs and ins not in fgraph.outputs and
ins.owner and ins.owner and
all([compute_map[v][0] for v in dependencies[ins]])): all([compute_map[v][0]
if ins not in view_of and not viewed_by.get(ins, []): for v in dependencies[ins]])):
if (ins not in view_of and
not viewed_by.get(ins, [])):
mem_freed += var_mem[ins] mem_freed += var_mem[ins]
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
...@@ -953,7 +960,8 @@ class ProfileStats(object): ...@@ -953,7 +960,8 @@ class ProfileStats(object):
viewedby_remove[origin].append(ins) viewedby_remove[origin].append(ins)
if (not viewed_by[origin] and if (not viewed_by[origin] and
origin not in fgraph.inputs and origin not in fgraph.inputs and
not isinstance(origin, theano.Constant)): not isinstance(origin,
theano.Constant)):
mem_freed += var_mem[origin] mem_freed += var_mem[origin]
else: else:
# ins is viewed_by something else, so its # ins is viewed_by something else, so its
...@@ -964,7 +972,8 @@ class ProfileStats(object): ...@@ -964,7 +972,8 @@ class ProfileStats(object):
done_set.add(node) done_set.add(node)
frozen_set = frozenset(done_set) frozen_set = frozenset(done_set)
if done_dict.get(frozen_set, max_mem_count + 1) > max_mem_count: if (done_dict.get(frozen_set, max_mem_count + 1) >
max_mem_count):
# check if frozen_set is in done_set # check if frozen_set is in done_set
# no, add it to done_set # no, add it to done_set
# yes, then compare the past mem and current mem # yes, then compare the past mem and current mem
...@@ -1008,7 +1017,8 @@ class ProfileStats(object): ...@@ -1008,7 +1017,8 @@ class ProfileStats(object):
# two data structure used to mimic Python gc # two data structure used to mimic Python gc
viewed_by = {} # {var1: [vars that view var1]} viewed_by = {} # {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref
# count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
...@@ -1043,28 +1053,29 @@ class ProfileStats(object): ...@@ -1043,28 +1053,29 @@ class ProfileStats(object):
max_sum_size = max(max_sum_size, sum_size) max_sum_size = max(max_sum_size, sum_size)
max_node_memory_size[0] = max(max_node_memory_size[0], max_node_memory_size[0] = max(max_node_memory_size[0],
sum(old_running_memory[0])) sum(old_running_memory[0]))
max_running_max_memory_size[0] = max(max_running_max_memory_size[0], max_running_max_memory_size[0] = \
sum(old_running_memory[2])) max(max_running_max_memory_size[0], sum(old_running_memory[2]))
# Separate CPU and GPU # Separate CPU and GPU
max_node_memory_size[1] = max(max_node_memory_size[1], max_node_memory_size[1] = max(max_node_memory_size[1],
old_running_memory[0][0]) old_running_memory[0][0])
max_node_memory_size[2] = max(max_node_memory_size[2], max_node_memory_size[2] = max(max_node_memory_size[2],
old_running_memory[0][1]) old_running_memory[0][1])
max_running_max_memory_size[1] = max(max_running_max_memory_size[1], max_running_max_memory_size[1] = \
old_running_memory[2][0]) max(max_running_max_memory_size[1], old_running_memory[2][0])
max_running_max_memory_size[2] = max(max_running_max_memory_size[2], max_running_max_memory_size[2] = \
old_running_memory[2][1]) max(max_running_max_memory_size[2], old_running_memory[2][1])
max_node_memory_saved_by_inplace = max( max_node_memory_saved_by_inplace = \
max_node_memory_saved_by_inplace, old_running_memory[3]) max(max_node_memory_saved_by_inplace, old_running_memory[3])
max_node_memory_saved_by_view = max(max_node_memory_saved_by_view, max_node_memory_saved_by_view = max(max_node_memory_saved_by_view,
old_running_memory[4]) old_running_memory[4])
# Store max of some stats with new order # Store max of some stats with new order
new_max_node_memory_size[0] = max(new_max_node_memory_size[0], new_max_node_memory_size[0] = max(new_max_node_memory_size[0],
sum(new_running_memory[0])) sum(new_running_memory[0]))
new_max_running_max_memory_size[0] = max(new_max_running_max_memory_size[0], new_max_running_max_memory_size[0] = \
max(new_max_running_max_memory_size[0],
sum(new_running_memory[2])) sum(new_running_memory[2]))
# Separate CPU and GPU # Separate CPU and GPU
...@@ -1072,15 +1083,18 @@ class ProfileStats(object): ...@@ -1072,15 +1083,18 @@ class ProfileStats(object):
new_running_memory[0][0]) new_running_memory[0][0])
new_max_node_memory_size[2] = max(new_max_node_memory_size[2], new_max_node_memory_size[2] = max(new_max_node_memory_size[2],
new_running_memory[0][1]) new_running_memory[0][1])
new_max_running_max_memory_size[1] = max(new_max_running_max_memory_size[1], new_max_running_max_memory_size[1] = \
max(new_max_running_max_memory_size[1],
new_running_memory[2][0]) new_running_memory[2][0])
new_max_running_max_memory_size[2] = max(new_max_running_max_memory_size[2], new_max_running_max_memory_size[2] = \
max(new_max_running_max_memory_size[2],
new_running_memory[2][1]) new_running_memory[2][1])
new_max_node_memory_saved_by_inplace = max( new_max_node_memory_saved_by_inplace = \
new_max_node_memory_saved_by_inplace, new_running_memory[3]) max(new_max_node_memory_saved_by_inplace,
new_max_node_memory_saved_by_view = max(new_max_node_memory_saved_by_view, new_running_memory[3])
new_running_memory[4]) new_max_node_memory_saved_by_view = \
max(new_max_node_memory_saved_by_view, new_running_memory[4])
# Config: whether print min memory peak # Config: whether print min memory peak
if config.profiling.min_peak_memory: if config.profiling.min_peak_memory:
...@@ -1093,8 +1107,8 @@ class ProfileStats(object): ...@@ -1093,8 +1107,8 @@ class ProfileStats(object):
del fgraph, nodes_mem del fgraph, nodes_mem
if len(fct_memory) > 1: if len(fct_memory) > 1:
print(("Memory Profile " print("Memory Profile (the max between all functions in "
"(the max between all functions in that profile)"), file=file) "that profile)", file=file)
else: else:
print("Memory Profile", file=file) print("Memory Profile", file=file)
...@@ -1129,17 +1143,21 @@ class ProfileStats(object): ...@@ -1129,17 +1143,21 @@ class ProfileStats(object):
print("---", file=file) print("---", file=file)
if min_max_peak: if min_max_peak:
print(" Minimum peak from all valid apply node order is %dKB(took %.3fs to compute)" % (int(round( print(" Minimum peak from all valid apply node order is "
min_max_peak / 1024.)), min_peak_time), file=file) "%dKB(took %.3fs to compute)" %
print(" Memory saved if views are used: %dKB (%dKB)" % (int( (int(round(min_max_peak / 1024.)), min_peak_time), file=file)
round(new_max_node_memory_saved_by_view / 1024.)), int( print(" Memory saved if views are used: %dKB (%dKB)" %
round(max_node_memory_saved_by_view / 1024.))), file=file) (int(round(new_max_node_memory_saved_by_view / 1024.)),
print(" Memory saved if inplace ops are used: %dKB (%dKB)" % \ int(round(max_node_memory_saved_by_view / 1024.))), file=file)
print(" Memory saved if inplace ops are used: %dKB (%dKB)" %
(int(round(new_max_node_memory_saved_by_inplace / 1024.)), (int(round(new_max_node_memory_saved_by_inplace / 1024.)),
int(round(max_node_memory_saved_by_inplace / 1024.))), file=file) int(round(max_node_memory_saved_by_inplace / 1024.))),
print(" Memory saved if gc is enabled: %dKB (%dKB)" % (int( file=file)
round(new_max_node_memory_size[0] - new_max_running_max_memory_size[0]) / 1024.), int( print(" Memory saved if gc is enabled: %dKB (%dKB)" %
round(max_node_memory_size[0] - max_running_max_memory_size[0]) / 1024.)), file=file) (int(round(new_max_node_memory_size[0] -
new_max_running_max_memory_size[0]) / 1024.),
int(round(max_node_memory_size[0] -
max_running_max_memory_size[0]) / 1024.)), file=file)
print("---", file=file) print("---", file=file)
...@@ -1148,19 +1166,19 @@ class ProfileStats(object): ...@@ -1148,19 +1166,19 @@ class ProfileStats(object):
hasattr(theano.sandbox.cuda, 'cuda_ndarray') and hasattr(theano.sandbox.cuda, 'cuda_ndarray') and
hasattr(theano.sandbox.cuda.cuda_ndarray.cuda_ndarray, hasattr(theano.sandbox.cuda.cuda_ndarray.cuda_ndarray,
'theano_allocated')): 'theano_allocated')):
_, gpu_max = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray.theano_allocated() cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
print((" Max Memory allocated on the GPU " _, gpu_max = cuda_ndarray.theano_allocated()
"(for all functions): %dKB" % print(" Max Memory allocated on the GPU (for all functions): "
int(round(gpu_max / 1024.))), file=file) "%dKB" % int(round(gpu_max / 1024.)), file=file)
print("", file=file) print("", file=file)
if len(fct_memory) > 1: if len(fct_memory) > 1:
print(( print(" This list is based on all functions in the profile",
" This list is based on all functions in the profile"), file=file) file=file)
print((" <Sum apply outputs (bytes)>" print(" <Sum apply outputs (bytes)>"
" <Apply outputs shape>" " <Apply outputs shape>"
" <created/inplace/view>" " <created/inplace/view>"
" <Apply node>"), file=file) " <Apply node>", file=file)
print("", file=file) print("", file=file)
items = node_mem.items() items = node_mem.items()
items.sort(key=lambda a: a[1], reverse=True) items.sort(key=lambda a: a[1], reverse=True)
...@@ -1181,9 +1199,8 @@ class ProfileStats(object): ...@@ -1181,9 +1199,8 @@ class ProfileStats(object):
else: else:
size = "%10s" % "Unknown" size = "%10s" % "Unknown"
print(' %s %s %s %s' % (size, print(' %s %s %s %s' % (size, shapes, ' '.join(code), node),
shapes, file=file)
' '.join(code), node), file=file)
sum_remaining = sum(size for _, size in items[N:]) sum_remaining = sum(size for _, size in items[N:])
size_sum_dense = sum(node_mem.values()) size_sum_dense = sum(node_mem.values())
...@@ -1191,23 +1208,21 @@ class ProfileStats(object): ...@@ -1191,23 +1208,21 @@ class ProfileStats(object):
p = "0%" p = "0%"
else: else:
p = "(%.2f%%)" % (float(sum_remaining) / size_sum_dense * 100) p = "(%.2f%%)" % (float(sum_remaining) / size_sum_dense * 100)
print(( print(' ... (remaining %i Apply account for %4dB/%dB (%s) of the'
' ... (remaining %i Apply account for %4dB/%dB (%s) of the' ' Apply with dense outputs sizes)' % (max(0, len(node_mem) - N),
' Apply with dense outputs sizes)') % (max(0, len(node_mem) - N),
sum_remaining, sum_remaining,
size_sum_dense, p size_sum_dense, p),
), file=file) file=file)
print('', file=file) print('', file=file)
if N == 0: if N == 0:
print((' All Apply nodes have output sizes that take' print(' All Apply nodes have output sizes that take less '
' less than %dB.' % 'than %dB.' % config.profiling.min_memory_size, file=file)
config.profiling.min_memory_size), file=file) print(" <created/inplace/view> is taken from the Op's declaration.",
print(( file=file)
" <created/inplace/view> is taken from the Op's declaration."), file=file) print(" Apply nodes marked 'inplace' or 'view' may"
print((" Apply nodes marked 'inplace' or 'view' may"
" actually allocate memory, this is not reported" " actually allocate memory, this is not reported"
" here. If you use DebugMode, warnings will be" " here. If you use DebugMode, warnings will be"
" emitted in those cases."), file=file) " emitted in those cases.", file=file)
print('', file=file) print('', file=file)
def summary(self, file=sys.stderr, n_ops_to_print=20, def summary(self, file=sys.stderr, n_ops_to_print=20,
...@@ -1220,8 +1235,8 @@ class ProfileStats(object): ...@@ -1220,8 +1235,8 @@ class ProfileStats(object):
self.summary_ops(file, n_ops_to_print) self.summary_ops(file, n_ops_to_print)
self.summary_nodes(file, n_apply_to_print) self.summary_nodes(file, n_apply_to_print)
elif self.fct_callcount > 0: elif self.fct_callcount > 0:
print((" No execution time accumulated " print(" No execution time accumulated "
"(hint: try config profiling.time_thunks=1)"), file=file) "(hint: try config profiling.time_thunks=1)", file=file)
if self.variable_shape or self.variable_strides: if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print) self.summary_memory(file, n_apply_to_print)
if self.optimizer_profile: if self.optimizer_profile:
...@@ -1231,7 +1246,7 @@ class ProfileStats(object): ...@@ -1231,7 +1246,7 @@ class ProfileStats(object):
self.optimizer_profile[1]) self.optimizer_profile[1])
if 0: # old code still to be ported from ProfileMode if False: # old code still to be ported from ProfileMode
def long_print(self, file=sys.stderr, fct_name=None, message=None, def long_print(self, file=sys.stderr, fct_name=None, message=None,
n_apply_to_print=15, n_ops_to_print=20, print_apply=False): n_apply_to_print=15, n_ops_to_print=20, print_apply=False):
""" """
......
"""Provide a simple user friendly API to Theano-managed memory""" """Provide a simple user friendly API to Theano-managed memory"""
__docformat__ = 'restructuredtext en'
# Standard imports # Standard imports
import copy import copy
import logging import logging
...@@ -12,6 +10,7 @@ import numpy ...@@ -12,6 +10,7 @@ import numpy
from theano.gof import Container, Variable, generic, utils from theano.gof import Container, Variable, generic, utils
_logger = logging.getLogger('theano.compile.sharedvalue') _logger = logging.getLogger('theano.compile.sharedvalue')
__docformat__ = 'restructuredtext en'
class SharedVariable(Variable): class SharedVariable(Variable):
...@@ -49,7 +48,8 @@ class SharedVariable(Variable): ...@@ -49,7 +48,8 @@ class SharedVariable(Variable):
or copied, so they must have the correct type. or copied, so they must have the correct type.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment. True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss. False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX. None -> only allow downcasting of a Python float to a scalar floatX.
...@@ -65,12 +65,13 @@ class SharedVariable(Variable): ...@@ -65,12 +65,13 @@ class SharedVariable(Variable):
if container is not None: if container is not None:
self.container = container self.container = container
if (value is not None) or (strict is not None): if (value is not None) or (strict is not None):
raise TypeError( raise TypeError('value and strict are ignored if you pass '
'value and strict are ignored if you pass a container here') 'a container here')
else: else:
if container is not None: if container is not None:
raise TypeError('Error to specify both value and container') raise TypeError('Error to specify both value and container')
self.container = Container(self, self.container = Container(
self,
storage=[type.filter(value, strict=strict, storage=[type.filter(value, strict=strict,
allow_downcast=allow_downcast)], allow_downcast=allow_downcast)],
readonly=False, readonly=False,
...@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs): ...@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
potential constructors to those that can accept those kwargs. potential constructors to those that can accept those kwargs.
:note: Some shared variable have ``borrow`` as extra kwargs. :note: Some shared variable have ``borrow`` as extra kwargs.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.html#borrowing-when-creating-shared-variables>`_ for detail. `See <http://deeplearning.net/software/theano/tutorial/aliasing.\
html#borrowing-when-creating-shared-variables>`_ for detail.
:note: Some shared variable have ``broadcastable`` as extra kwargs. :note: Some shared variable have ``broadcastable`` as extra kwargs.
As shared variable shapes can change, all dimensions default As shared variable shapes can change, all dimensions default
...@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs): ...@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
try: try:
if isinstance(value, Variable): if isinstance(value, Variable):
raise TypeError(" Shared variable constructor needs numeric values and not symbolic variables.") raise TypeError("Shared variable constructor needs numeric "
"values and not symbolic variables.")
for ctor in reversed(shared.constructors): for ctor in reversed(shared.constructors):
try: try:
......
import cPickle, logging import cPickle
import logging
_logger = logging.getLogger("theano.gof.callcache") _logger = logging.getLogger("theano.gof.callcache")
...@@ -18,9 +19,6 @@ class CallCache(object): ...@@ -18,9 +19,6 @@ class CallCache(object):
def persist(self, filename=None): def persist(self, filename=None):
if filename is None: if filename is None:
filename = self.filename filename = self.filename
# backport
#filename = self.filename if filename is None else filename
f = open(filename, 'w') f = open(filename, 'w')
cPickle.dump(self.cache, f) cPickle.dump(self.cache, f)
f.close() f.close()
...@@ -28,9 +26,6 @@ class CallCache(object): ...@@ -28,9 +26,6 @@ class CallCache(object):
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
if key is None: if key is None:
key = (fn, tuple(args)) key = (fn, tuple(args))
# backport
#key = (fn, tuple(args)) if key is None else key
if key not in self.cache: if key not in self.cache:
_logger.debug('cache miss %i', len(self.cache)) _logger.debug('cache miss %i', len(self.cache))
self.cache[key] = fn(*args) self.cache[key] = fn(*args)
......
...@@ -8,7 +8,6 @@ import re ...@@ -8,7 +8,6 @@ import re
import shutil import shutil
import struct import struct
import socket import socket
import subprocess
import sys import sys
import textwrap import textwrap
...@@ -295,7 +294,8 @@ def cleanup(): ...@@ -295,7 +294,8 @@ def cleanup():
have_npy_abi_version = True have_npy_abi_version = True
elif obj.startswith('c_compiler_str='): elif obj.startswith('c_compiler_str='):
have_c_compiler = True have_c_compiler = True
elif (isinstance(obj, (theano.gof.Op, theano.gof.Type)) and elif (isinstance(obj, (theano.gof.Op,
theano.gof.Type)) and
hasattr(obj, 'c_code_cache_version')): hasattr(obj, 'c_code_cache_version')):
v = obj.c_code_cache_version() v = obj.c_code_cache_version()
if v not in [(), None] and v not in key[0]: if v not in [(), None] and v not in key[0]:
...@@ -310,7 +310,7 @@ def cleanup(): ...@@ -310,7 +310,7 @@ def cleanup():
if keydata.key_pkl != filename: if keydata.key_pkl != filename:
keydata.key_pkl = filename keydata.key_pkl = filename
keydata.remove_key(key) keydata.remove_key(key)
except IOError as e: except IOError:
_logger.error( _logger.error(
"Could not remove file '%s'. To complete " "Could not remove file '%s'. To complete "
"the clean-up, please remove manually " "the clean-up, please remove manually "
...@@ -395,7 +395,7 @@ def print_compiledir_content(): ...@@ -395,7 +395,7 @@ def print_compiledir_content():
if big_key_files: if big_key_files:
big_key_files = sorted(big_key_files, key=lambda t: str(t[1])) big_key_files = sorted(big_key_files, key=lambda t: str(t[1]))
big_total_size = sum([size for dir, size, ops in big_key_files]) big_total_size = sum([sz for _, sz, _ in big_key_files])
print(("There are directories with key files bigger than %d bytes " print(("There are directories with key files bigger than %d bytes "
"(they probably contain big tensor constants)" % "(they probably contain big tensor constants)" %
max_key_file_size)) max_key_file_size))
......
...@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw): ...@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw):
# the lock state and raise an error. # the lock state and raise an error.
while get_lock.n_lock > 0: while get_lock.n_lock > 0:
release_lock() release_lock()
raise Exception("For some unknow reason, the lock was already taken," raise Exception("For some unknow reason, the lock was already "
" but no start time was registered.") "taken, but no start time was registered.")
now = time.time() now = time.time()
if now - get_lock.start_time > config.compile.timeout/2: if now - get_lock.start_time > config.compile.timeout/2:
lockpath = os.path.join(get_lock.lock_dir, 'lock') lockpath = os.path.join(get_lock.lock_dir, 'lock')
......
...@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils_code(): def compile_cutils_code():
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128', types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'int256', 'uint8', 'uint16', 'uint32',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64',
'float80', 'float96', 'float128',
'float256']] 'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64', complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']] 'complex128', 'complex160',
'complex192', 'complex512']]
inplace_map_template = """ inplace_map_template = """
#if defined(%(typen)s) #if defined(%(typen)s)
static void %(type)s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set) static void %(type)s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{ {
int index = mit->size; int index = mit->size;
while (index--) { while (index--) {
...@@ -38,10 +41,13 @@ def compile_cutils_code(): ...@@ -38,10 +41,13 @@ def compile_cutils_code():
#endif #endif
""" """
floatadd = "((%(type)s*)mit->dataptr)[0] = inc_or_set * ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];" floatadd = ("((%(type)s*)mit->dataptr)[0] = inc_or_set * "
"((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];")
complexadd = """ complexadd = """
((%(type)s*)mit->dataptr)[0].real = inc_or_set * ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real; ((%(type)s*)mit->dataptr)[0].real = inc_or_set *
((%(type)s*)mit->dataptr)[0].imag = inc_or_set * ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag; ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = inc_or_set *
((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
""" """
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(), fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
...@@ -51,33 +57,36 @@ def compile_cutils_code(): ...@@ -51,33 +57,36 @@ def compile_cutils_code():
'op': complexadd % {'type': t}} 'op': complexadd % {'type': t}}
for t in complex_types]) for t in complex_types])
def gen_binop(type, typen):
return """
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % dict(type=type, typen=typen)
fn_array = ("static inplace_map_binop addition_funcs[] = {" + fn_array = ("static inplace_map_binop addition_funcs[] = {" +
''.join([""" ''.join([gen_binop(type=t, typen=t.upper())
#if defined(%(typen)s) for t in types + complex_types]) + "NULL};\n")
%(type)s_inplace_add,
#endif def gen_num(typen):
""" % {'type': t, 'typen': t.upper()} return """
for t in types + complex_types]) + #if defined(%(typen)s)
"""NULL}; %(typen)s,
""") #endif
""" % dict(type=type, typen=typen)
type_number_array = ("static int type_numbers[] = {" + type_number_array = ("static int type_numbers[] = {" +
''.join([""" ''.join([gen_num(typen=t.upper())
#if defined(%(typen)s) for t in types + complex_types]) + "-1000};")
%(typen)s,
#endif
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
code = (""" code = ("""
#if NPY_API_VERSION >= 0x00000008 #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set); typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
""" + fns + fn_array + type_number_array + PyArrayIterObject *, int inc_or_set);
""" + fns + fn_array + type_number_array + """
"""
static int static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set) map_increment(PyArrayMapIterObject *mit, PyObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{ {
PyArrayObject *arr = NULL; PyArrayObject *arr = NULL;
PyArrayIterObject *it; PyArrayIterObject *it;
...@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL; return NULL;
} }
if (!PyArray_Check(arg_a)) { if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument"); PyErr_SetString(PyExc_ValueError,
"needs an ndarray as first argument");
return NULL; return NULL;
} }
...@@ -285,7 +295,7 @@ try: ...@@ -285,7 +295,7 @@ try:
open(os.path.join(location, '__init__.py'), 'w').close() open(os.path.join(location, '__init__.py'), 'w').close()
try: try:
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
except ImportError: except ImportError:
get_lock() get_lock()
# Ensure no-one else is currently modifying the content of the compilation # Ensure no-one else is currently modifying the content of the compilation
...@@ -296,11 +306,11 @@ try: ...@@ -296,11 +306,11 @@ try:
# We must retry to import it as some other process could # We must retry to import it as some other process could
# have been compiling it between the first failed import # have been compiling it between the first failed import
# and when we receive the lock # and when we receive the lock
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
except ImportError: except ImportError:
compile_cutils() compile_cutils()
from cutils_ext.cutils_ext import * from cutils_ext.cutils_ext import * # noqa
finally: finally:
# Release lock on compilation directory. # Release lock on compilation directory.
......
...@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c') ...@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile = False force_compile = False
version = 0.21 # must match constant returned in function get_version() version = 0.21 # must match constant returned in function get_version()
lazylinker_ext = None
def try_import(): def try_import():
global lazylinker_ext global lazylinker_ext
sys.path[0:0] = [config.compiledir] sys.path[0:0] = [config.compiledir]
import lazylinker_ext import lazylinker_ext # noqa
del sys.path[0] del sys.path[0]
...@@ -43,11 +44,11 @@ try: ...@@ -43,11 +44,11 @@ try:
# Try to make the location # Try to make the location
os.mkdir(location) os.mkdir(location)
except OSError as e: except OSError as e:
# If we get an error, verify that the error was # 17, the path already exists, # If we get an error, verify that the error was # 17, the
# and that it is a directory # path already exists, and that it is a directory Note: we
# Note: we can't check if it exists before making it, because we are not holding # can't check if it exists before making it, because we
# the lock right now, so we could race another process and get error 17 if we lose # are not holding the lock right now, so we could race
# the race # another process and get error 17 if we lose the race
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.isdir(location) assert os.path.isdir(location)
...@@ -142,5 +143,5 @@ except ImportError: ...@@ -142,5 +143,5 @@ except ImportError:
# Release lock on compilation directory. # Release lock on compilation directory.
release_lock() release_lock()
from lazylinker_ext.lazylinker_ext import * from lazylinker_ext.lazylinker_ext import * # noqa
assert force_compile or (version == get_version()) assert force_compile or (version == get_version())
...@@ -32,7 +32,7 @@ class DB(object): ...@@ -32,7 +32,7 @@ class DB(object):
self.__db__ = DefaultOrderedDict(OrderedSet) self.__db__ = DefaultOrderedDict(OrderedSet)
self._names = set() self._names = set()
self.name = None # will be reset by register self.name = None # will be reset by register
#(via obj.name by the thing doing the registering) # (via obj.name by the thing doing the registering)
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
""" """
...@@ -175,8 +175,10 @@ class Query(object): ...@@ -175,8 +175,10 @@ class Query(object):
self.exclude = OrderedSet(self.exclude) self.exclude = OrderedSet(self.exclude)
def __str__(self): def __str__(self):
return "Query{inc=%s,ex=%s,require=%s,subquery=%s,position_cutoff=%d}" % ( return ("Query{inc=%s,ex=%s,require=%s,subquery=%s,"
self.include, self.exclude, self.require, self.subquery, self.position_cutoff) "position_cutoff=%d}" %
(self.include, self.exclude, self.require, self.subquery,
self.position_cutoff))
# add all opt with this tag # add all opt with this tag
def including(self, *tags): def including(self, *tags):
...@@ -268,7 +270,7 @@ class SequenceDB(DB): ...@@ -268,7 +270,7 @@ class SequenceDB(DB):
position_cutoff = kwtags.pop('position_cutoff', position_cutoff = kwtags.pop('position_cutoff',
config.optdb.position_cutoff) config.optdb.position_cutoff)
if len(tags) >= 1 and isinstance(tags[0], Query): if len(tags) >= 1 and isinstance(tags[0], Query):
# the call to super should have raise an error with a good message # the call to super should have raise an error with a good message
assert len(tags) == 1 assert len(tags) == 1
if getattr(tags[0], 'position_cutoff', None): if getattr(tags[0], 'position_cutoff', None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
......
...@@ -39,8 +39,8 @@ def make_depends(): ...@@ -39,8 +39,8 @@ def make_depends():
def depends(pair): def depends(pair):
""" Returns True if a depends on b """ """ Returns True if a depends on b """
a, b = pair a, b = pair
return (any(bout in a.inputs for bout in b.outputs) return (any(bout in a.inputs for bout in b.outputs) or
or any(depends((ainp.owner, b)) for ainp in a.inputs any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner)) if ainp.owner))
return depends return depends
...@@ -160,12 +160,12 @@ def posort(l, *cmps): ...@@ -160,12 +160,12 @@ def posort(l, *cmps):
for b in l: for b in l:
assert not(b in comes_after[a] and a in comes_after[b]) assert not(b in comes_after[a] and a in comes_after[b])
for cmp in cmps: for cmp_fn in cmps:
for a in l: for a in l:
for b in l: for b in l:
if cmp(a, b) < 0: # a wants to come before b if cmp_fn(a, b) < 0: # a wants to come before b
# if this wouldn't cause a cycle and isn't already known # if this wouldn't cause a cycle and isn't already known
if not b in comes_before[a] and not b in comes_after[a]: if b not in comes_before[a] and b not in comes_after[a]:
add_links(a, b) add_links(a, b)
# check() # debug code # check() # debug code
......
...@@ -36,8 +36,11 @@ def test_give_variables_names_small(): ...@@ -36,8 +36,11 @@ def test_give_variables_names_small():
def test_remove(): def test_remove():
even = lambda x: x % 2 == 0 def even(x):
odd = lambda x: x % 2 == 1 return x % 2 == 0
# The list are neede as with python 3, remove and filter return generators
def odd(x):
return x % 2 == 1
# The list are needed as with python 3, remove and filter return generators
# and we can't compare generators. # and we can't compare generators.
assert list(remove(even, range(5))) == list(filter(odd, range(5))) assert list(remove(even, range(5))) == list(filter(odd, range(5)))
...@@ -214,9 +214,9 @@ class Validator(Feature): ...@@ -214,9 +214,9 @@ class Validator(Feature):
class ReplaceValidate(History, Validator): class ReplaceValidate(History, Validator):
pickle_rm_attr = ["replace_validate", "replace_all_validate", pickle_rm_attr = (["replace_validate", "replace_all_validate",
"replace_all_validate_remove"] + \ "replace_all_validate_remove"] +
History.pickle_rm_attr + Validator.pickle_rm_attr History.pickle_rm_attr + Validator.pickle_rm_attr)
def on_attach(self, fgraph): def on_attach(self, fgraph):
for attr in ('replace_validate', 'replace_all_validate', for attr in ('replace_validate', 'replace_all_validate',
...@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator): ...@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator):
try: try:
fgraph.replace(r, new_r, reason=reason, verbose=False) fgraph.replace(r, new_r, reason=reason, verbose=False)
except Exception as e: except Exception as e:
if ('The type of the replacement must be the same' not in msg = str(e)
str(e) and 'does not belong to this FunctionGraph' not in str(e)): s1 = 'The type of the replacement must be the same'
s2 = 'does not belong to this FunctionGraph'
if (s1 not in msg and s2 not in msg):
out = sys.stderr out = sys.stderr
print("<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>", end=' ', file=out) print("<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>",
print(type(e), e, reason, file=out) type(e), e, reason, file=out)
# this might fail if the error is in a listener: # this might fail if the error is in a listener:
# (fgraph.replace kinda needs better internal error handling) # (fgraph.replace kinda needs better internal error handling)
fgraph.revert(chk) fgraph.revert(chk)
...@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator): ...@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator):
fgraph.revert(chk) fgraph.revert(chk)
if warn: if warn:
out = sys.stderr out = sys.stderr
print(( print(
"WARNING: An optimization wanted to replace a Variable" "WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't" " in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization." " remove it. We disabled the optimization."
" Your function runs correctly, but it would be" " Your function runs correctly, but it would be"
" appreciated if you submit this problem to the" " appreciated if you submit this problem to the"
" mailing list theano-users so that we can fix it."), file=out) " mailing list theano-users so that we can fix it.",
file=out)
print(reason, replacements, file=out) print(reason, replacements, file=out)
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
...@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper): ...@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper):
def on_attach(self, fgraph): def on_attach(self, fgraph):
if self.fgraph is not None: if self.fgraph is not None:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.") raise Exception("A NodeFinder instance can only serve one "
"FunctionGraph.")
if hasattr(fgraph, 'get_nodes'): if hasattr(fgraph, 'get_nodes'):
raise AlreadyThere("NodeFinder is already present or in conflict" raise AlreadyThere("NodeFinder is already present or in conflict"
" with another plugin.") " with another plugin.")
......
"""WRITEME Defines the `Type` class.""" """WRITEME Defines the `Type` class."""
__docformat__ = "restructuredtext en"
from theano.compat import PY3 from theano.compat import PY3
from theano.gof import utils from theano.gof import utils
...@@ -13,6 +10,8 @@ from theano.gof import graph ...@@ -13,6 +10,8 @@ from theano.gof import graph
######## ########
from theano.gof.op import CLinkerObject from theano.gof.op import CLinkerObject
__docformat__ = "restructuredtext en"
class CLinkerType(CLinkerObject): class CLinkerType(CLinkerObject):
"""Interface specification for Types that can be arguments to a `CLinkerOp`. """Interface specification for Types that can be arguments to a `CLinkerOp`.
...@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject): ...@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise MethodNotDefined("c_literal", type(self), self.__class__.__name__) raise MethodNotDefined("c_literal", type(self),
self.__class__.__name__)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
"""Required: Return c code to declare variables that will be """Required: Return c code to declare variables that will be
...@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject): ...@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject):
return "PyObject ** addr_of_%(name)s;" return "PyObject ** addr_of_%(name)s;"
:param name: the name of the ``PyObject *`` pointer that will the value for this Type :param name: the name of the ``PyObject *`` pointer that will
the value for this Type
:type name: string :type name: string
...@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject): ...@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method - `MethodNotDefined`: Subclass does not implement this method
""" """
raise MethodNotDefined("c_extract", type(self), self.__class__.__name__) raise MethodNotDefined("c_extract", type(self),
self.__class__.__name__)
def c_extract_out(self, name, sub, check_input=True): def c_extract_out(self, name, sub, check_input=True):
"""Optional: C code to extract a PyObject * instance. """Optional: C code to extract a PyObject * instance.
...@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject): ...@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject):
def c_sync(self, name, sub): def c_sync(self, name, sub):
"""Required: Return c code to pack C types back into a PyObject. """Required: Return c code to pack C types back into a PyObject.
The code returned from this function must be templated using "%(name)s", The code returned from this function must be templated using
representing the name that the caller wants to call this Variable. The "%(name)s", representing the name that the caller wants to
returned code may set "py_%(name)s" to a PyObject* and that PyObject* call this Variable. The returned code may set "py_%(name)s"
will be accessible from Python via variable.data. Do not forget to adjust to a PyObject* and that PyObject* will be accessible from
reference counts if "py_%(name)s" is changed from its original value. Python via variable.data. Do not forget to adjust reference
counts if "py_%(name)s" is changed from its original value.
:Parameters: :Parameters:
- `name`: WRITEME - `name`: WRITEME
...@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject): ...@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject):
def c_code_cache_version(self): def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Type. """Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes. An empty tuple indicates an 'unversioned' Type that will not
be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer The cache mechanism may erase cached modules that have been
versions. See `ModuleCache` for details. superceded by newer versions. See `ModuleCache` for details.
""" """
return () return ()
...@@ -221,19 +225,21 @@ class PureType(object): ...@@ -221,19 +225,21 @@ class PureType(object):
- creating `Variable` instances (conventionally, `__call__` does this), and - creating `Variable` instances (conventionally, `__call__` does this), and
- filtering a value assigned to a `Variable` so that the value conforms to restrictions - filtering a value assigned to a `Variable` so that the value
imposed by the type (also known as casting, this is done by `filter`), conforms to restrictions imposed by the type (also known as
casting, this is done by `filter`),
""" """
# the type that will be created by call to make_variable.
Variable = graph.Variable
Variable = graph.Variable # the type that will be created by call to make_variable. # the type that will be created by call to make_constant
Constant = graph.Constant # the type that will be created by call to make_constant Constant = graph.Constant
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
"""Required: Return data or an appropriately wrapped/converted data. """Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if the data is not of an Subclass implementation should raise a TypeError exception if
acceptable type. the data is not of an acceptable type.
If strict is True, the data returned must be the same as the If strict is True, the data returned must be the same as the
data passed as an argument. If it is False, and allow_downcast data passed as an argument. If it is False, and allow_downcast
...@@ -283,7 +289,8 @@ class PureType(object): ...@@ -283,7 +289,8 @@ class PureType(object):
return other return other
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type""" """Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type"""
try: try:
self.filter(a, strict=True) self.filter(a, strict=True)
return True return True
...@@ -291,7 +298,8 @@ class PureType(object): ...@@ -291,7 +298,8 @@ class PureType(object):
return False return False
def value_validity_msg(self, a): def value_validity_msg(self, a):
"""Optional: return a message explaining the output of is_valid_value""" """Optional: return a message explaining the output of
is_valid_value"""
return "none" return "none"
def make_variable(self, name=None): def make_variable(self, name=None):
...@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType): ...@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType):
But you are encouraged to write your own, as described in WRITEME. But you are encouraged to write your own, as described in WRITEME.
The following following code illustrates the use of a Type instance, here tensor.fvector: The following following code illustrates the use of a Type
instance, here tensor.fvector:
.. code-block:: python .. code-block:: python
...@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType): ...@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType):
# Create a second Variable with the same Type instance # Create a second Variable with the same Type instance
c = tensor.fvector() c = tensor.fvector()
Whenever you create a symbolic variable in theano (technically, `Variable`) it will contain a Whenever you create a symbolic variable in theano (technically,
reference to a Type instance. That reference is typically constant during the lifetime of `Variable`) it will contain a reference to a Type instance. That
the Variable. Many variables can refer to a single Type instance, as do b and c above. The reference is typically constant during the lifetime of the
Type instance defines the kind of value which might end up in that variable when executing Variable. Many variables can refer to a single Type instance, as
a `Function`. In this sense, theano is like a strongly-typed language because the types do b and c above. The Type instance defines the kind of value
are included in the graph before the values. In our example above, b is a Variable which is which might end up in that variable when executing a `Function`.
guaranteed to correspond to a numpy.ndarray of rank 1 when we try to do some computations In this sense, theano is like a strongly-typed language because
the types are included in the graph before the values. In our
example above, b is a Variable which is guaranteed to correspond
to a numpy.ndarray of rank 1 when we try to do some computations
with it. with it.
Many `Op` instances will raise an exception if they are applied to inputs with incorrect Many `Op` instances will raise an exception if they are applied to
types. Type references are also useful to do type-checking in pattern-based optimizations. inputs with incorrect types. Type references are also useful to
do type-checking in pattern-based optimizations.
""" """
def convert_variable(self, var): def convert_variable(self, var):
...@@ -451,8 +464,8 @@ class Generic(SingletonType): ...@@ -451,8 +464,8 @@ class Generic(SingletonType):
""" """
Represents a generic Python object. Represents a generic Python object.
This class implements the `PureType` and `CLinkerType` interfaces for generic PyObject This class implements the `PureType` and `CLinkerType` interfaces
instances. for generic PyObject instances.
EXAMPLE of what this means, or when you would use this type. EXAMPLE of what this means, or when you would use this type.
......
from __future__ import print_function from __future__ import print_function
import linecache import linecache
import traceback import traceback
import re
import sys import sys
from theano import config from theano import config
...@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None): ...@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None):
This is because this update cause an call to os.stat to get the This is because this update cause an call to os.stat to get the
line content. This cause too much long on cluster. line content. This cause too much long on cluster.
""" """
if f is None: if f is None:
try: try:
...@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4): ...@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4):
# I enable my implementation only for some python version just to # I enable my implementation only for some python version just to
# be sure the Python internal do not change. If this work with # be sure the Python internal do not change. If this work with
# other python version, you can enable it. # other python version, you can enable it.
simple_extract_stack = traceback.extract_stack simple_extract_stack = traceback.extract_stack # noqa
def add_tag_trace(thing, user_line=1): def add_tag_trace(thing, user_line=1):
...@@ -190,8 +188,8 @@ def deprecated(filename, msg=''): ...@@ -190,8 +188,8 @@ def deprecated(filename, msg=''):
def g(*args, **kwargs): def g(*args, **kwargs):
if printme[0]: if printme[0]:
print('WARNING: %s.%s deprecated. %s'\ print('WARNING: %s.%s deprecated. %s' %
% (filename, f.__name__, msg)) (filename, f.__name__, msg))
printme[0] = False printme[0] = False
return f(*args, **kwargs) return f(*args, **kwargs)
return g return g
...@@ -220,7 +218,7 @@ def difference(seq1, seq2): ...@@ -220,7 +218,7 @@ def difference(seq1, seq2):
raise Exception('not worth it') raise Exception('not worth it')
set2 = set(seq2) set2 = set(seq2)
return [x for x in seq1 if x not in set2] return [x for x in seq1 if x not in set2]
except Exception as e: except Exception:
# maybe a seq2 element is not hashable # maybe a seq2 element is not hashable
# maybe seq2 is too short # maybe seq2 is too short
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
...@@ -311,11 +309,11 @@ def comm_guard(type1, type2): ...@@ -311,11 +309,11 @@ def comm_guard(type1, type2):
old_f = f.func_globals[f.__name__] old_f = f.func_globals[f.__name__]
def new_f(arg1, arg2, *rest): def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) \ if ((type1 is ANY_TYPE or isinstance(arg1, type1)) and
and (type2 is ANY_TYPE or isinstance(arg2, type2)): (type2 is ANY_TYPE or isinstance(arg2, type2))):
pass pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) \ elif ((type1 is ANY_TYPE or isinstance(arg2, type1)) and
and (type2 is ANY_TYPE or isinstance(arg1, type2)): (type2 is ANY_TYPE or isinstance(arg1, type2))):
arg1, arg2 = arg2, arg1 arg1, arg2 = arg2, arg1
else: else:
return old_f(arg1, arg2, *rest) return old_f(arg1, arg2, *rest)
...@@ -337,7 +335,8 @@ def comm_guard(type1, type2): ...@@ -337,7 +335,8 @@ def comm_guard(type1, type2):
return type.__name__ return type.__name__
new_f.__doc__ = (str(old_f.__doc__) + "\n" + new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1, type2)]) + ", ".join([typename(type)
for type in (type1, type2)]) +
"\n" + str(f.__doc__ or "")) "\n" + str(f.__doc__ or ""))
return new_f return new_f
...@@ -406,15 +405,16 @@ def give_variables_names(variables): ...@@ -406,15 +405,16 @@ def give_variables_names(variables):
This function is idempotent.""" This function is idempotent."""
names = map(lambda var: var.name, variables) names = map(lambda var: var.name, variables)
h = hist(names) h = hist(names)
bad_var = lambda var: not var.name or h[var.name] > 1
def bad_var(var):
return not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)): for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d" % i var.name = (var.name or "") + "_%d" % i
if not unique(map(str, variables)): if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names." raise ValueError("Not all variables have unique names. Maybe you've "
"Maybe you've named some of the variables identically") "named some of the variables identically")
return variables return variables
......
...@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy', ...@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy',
in_c_key=False) in_c_key=False)
def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, dependencies): def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
dependencies):
reallocated_info = {} reallocated_info = {}
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
...@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
assert len( assert len(idx_v) == 1, ("Here we only support the possibility"
idx_v) == 1, "Here we only support the possibility to destroy one input" " to destroy one input")
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap: if vmap and idx_o in vmap:
assert ins is None assert ins is None
idx_v = vmap[idx_o] idx_v = vmap[idx_o]
assert len( assert len(idx_v) == 1, ("Here we only support the possibility"
idx_v) == 1, "Here we only support the possibility to view one input" " to view one input")
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if ins is not None: if ins is not None:
assert isinstance(ins, theano.Variable) assert isinstance(ins, theano.Variable)
...@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] if (getattr(ins, 'ndim', None) == 0 and not storage_map[ins][0] and
and ins not in fgraph.outputs and ins.owner ins not in fgraph.outputs and ins.owner and
and all([compute_map_re[v][0] for v in dependencies.get(ins, [])]) all([compute_map_re[v][0]
and ins not in allocated): for v in dependencies.get(ins, [])]) and
ins not in allocated):
# Constant Memory cannot be changed # Constant Memory cannot be changed
# Constant and shared variables' storage_map value is not empty # Constant and shared variables' storage_map value is not empty
reuse_out = None reuse_out = None
...@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated if (getattr(out, 'ndim', None) == 0 and
and ins.type == out.type): out not in pre_allocated and
ins.type == out.type):
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
...@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if reuse_out: if reuse_out:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and out not in pre_allocated if (getattr(out, 'ndim', None) == 0 and
and ins.type == out.type): out not in pre_allocated and
ins.type == out.type):
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
...@@ -508,7 +512,8 @@ class Stack(VM): ...@@ -508,7 +512,8 @@ class Stack(VM):
st = "c" st = "c"
self.variable_strides[var] = st self.variable_strides[var] = st
except Exception: except Exception:
link.raise_with_op(current_apply, link.raise_with_op(
current_apply,
self.thunks[self.node_idx[current_apply]], self.thunks[self.node_idx[current_apply]],
storage_map=storage_map) storage_map=storage_map)
for o in current_apply.outputs: for o in current_apply.outputs:
...@@ -521,9 +526,9 @@ class Stack(VM): ...@@ -521,9 +526,9 @@ class Stack(VM):
for i in current_apply.inputs: for i in current_apply.inputs:
# Garbage Collection -> check if anybody else uses # Garbage Collection -> check if anybody else uses
# this input # this input
if (dependencies[i] if (dependencies[i] and
and i.owner i.owner and
and i not in self.outputs): i not in self.outputs):
if all(compute_map[v][0] if all(compute_map[v][0]
for v in dependencies[i]): for v in dependencies[i]):
storage_map[i][0] = None storage_map[i][0] = None
...@@ -544,10 +549,13 @@ class Stack(VM): ...@@ -544,10 +549,13 @@ class Stack(VM):
'destroy_map', 'destroy_map',
False)): False)):
warnings.warn( warnings.warn(
"There was a bug that existed in the default Theano configuration," "There was a bug that existed in "
" only in the development version between July 5th 2012" "the default Theano configuration,"
" and July 30th 2012. This was not in a released version." " only in the development version "
" The bug was affecting this script.", "between July 5th 2012 and "
"July 30th 2012. This was not in "
"a released version. The bug was "
"affecting this script.",
# The stack level is not good when # The stack level is not good when
# inside a Scan. # inside a Scan.
stacklevel=3 stacklevel=3
...@@ -578,7 +586,8 @@ class Stack(VM): ...@@ -578,7 +586,8 @@ class Stack(VM):
self.call_times[current_idx] += dt self.call_times[current_idx] += dt
except Exception: except Exception:
link.raise_with_op(current_apply, link.raise_with_op(
current_apply,
self.thunks[self.node_idx[current_apply]], self.thunks[self.node_idx[current_apply]],
storage_map=storage_map) storage_map=storage_map)
...@@ -639,7 +648,7 @@ class Stack(VM): ...@@ -639,7 +648,7 @@ class Stack(VM):
if self.allow_gc: if self.allow_gc:
for v in storage_map: for v in storage_map:
if v.owner and not v in self.outputs: if v.owner and v not in self.outputs:
if compute_map[v][0] == 2: if compute_map[v][0] == 2:
continue continue
else: else:
...@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker): ...@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker):
vars_idx_inv[i] = var vars_idx_inv[i] = var
# put storage_map and compute_map into a int-based scheme # put storage_map and compute_map into a int-based scheme
n_applies = len(nodes)
storage_map_list = [storage_map[vars_idx_inv[i]] storage_map_list = [storage_map[vars_idx_inv[i]]
for i in xrange(len(vars_idx_inv))] for i in xrange(len(vars_idx_inv))]
compute_map_list = [compute_map[vars_idx_inv[i]] compute_map_list = [compute_map[vars_idx_inv[i]]
...@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker): ...@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker):
else: else:
dependencies = self.compute_gc_dependencies(storage_map) dependencies = self.compute_gc_dependencies(storage_map)
reallocated_info = calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,dependencies) reallocated_info = calculate_reallocate_info(
order, fgraph, storage_map, compute_map_re, dependencies)
for node in order: for node in order:
try: try:
...@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker):
lazy = config.vm.lazy lazy = config.vm.lazy
if lazy is None: if lazy is None:
lazy = not all([(not th.lazy) for th in thunks]) lazy = not all([(not th.lazy) for th in thunks])
if not (lazy or (config.profile and config.profile_memory) or self.use_cloop or self.callback): if not (lazy or (config.profile and config.profile_memory) or
self.use_cloop or self.callback):
for pair in reallocated_info.values(): for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]] storage_map[pair[1]] = storage_map[pair[0]]
...@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker): ...@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker):
for node in order: for node in order:
clear_after_this_thunk = [] clear_after_this_thunk = []
for input in node.inputs: for input in node.inputs:
if ((input in computed) if (input in computed and
and (input not in fgraph.outputs) input not in fgraph.outputs and
and (node == last_user[input]) node == last_user[input] and
and input not in reallocated_info.keys()): input not in reallocated_info.keys()):
clear_after_this_thunk.append(storage_map[input]) clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk) post_thunk_clear.append(clear_after_this_thunk)
else: else:
......
...@@ -2,7 +2,6 @@ from functools import wraps ...@@ -2,7 +2,6 @@ from functools import wraps
import numpy import numpy
import theano
from theano import scalar as scal, Constant from theano import scalar as scal, Constant
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
......
...@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt ...@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt
# Skip tests if cuda_ndarray is not available. # Skip tests if cuda_ndarray is not available.
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano.sandbox.cuda as cuda_ndarray import theano.sandbox.cuda as cuda_ndarray
if not cuda_ndarray.cuda_available: if not cuda_ndarray.cuda_available: # noqa
raise SkipTest('Optional package cuda not available') raise SkipTest('Optional package cuda not available')
from theano.misc.pycuda_init import pycuda_available from theano.misc.pycuda_init import pycuda_available
if not pycuda_available: if not pycuda_available: # noqa
raise SkipTest('Optional package pycuda not available') raise SkipTest('Optional package pycuda not available')
from theano.sandbox.cuda.fftconv import scikits_cuda_available from theano.sandbox.cuda.fftconv import scikits_cuda_available
if not scikits_cuda_available: if not scikits_cuda_available: # noqa
raise SkipTest('Optional package scikits.cuda not available') raise SkipTest('Optional package scikits.cuda not available')
from theano.sandbox.cuda import float32_shared_constructor as shared from theano.sandbox.cuda import float32_shared_constructor as shared
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# mpiexec -np 2 python _test_mpi_roundtrip.py # mpiexec -np 2 python _test_mpi_roundtrip.py
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD
import theano import theano
from theano.tensor.io import send, recv, mpi_cmps from theano.tensor.io import send, recv, mpi_cmps
from theano.gof.sched import sort_schedule_fn from theano.gof.sched import sort_schedule_fn
import numpy as np import numpy as np
from sys import stdout, stderr, exit from sys import stdout, stderr, exit
comm = MPI.COMM_WORLD
rank = comm.Get_rank() rank = comm.Get_rank()
size = comm.Get_size() size = comm.Get_size()
......
from datetime import datetime
__authors__ = "Ian Goodfellow" __authors__ = "Ian Goodfellow"
__credits__ = ["Ian Goodfellow"] __credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD" __license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow" __maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro" __email__ = "goodfeli@iro"
from datetime import datetime
def disturb_mem(): def disturb_mem():
# Allocate a time-dependent amount of objects to increase # Allocate a time-dependent amount of objects to increase
......
from __future__ import print_function from __future__ import print_function
import os, unittest, sys import os
import nose.plugins.builtin import unittest
import sys
from nose.config import Config from nose.config import Config
from nose.plugins.manager import PluginManager from nose.plugins.manager import PluginManager
from numpy.testing.nosetester import import_nose, NoseTester import nose.plugins.builtin
from numpy.testing.nosetester import NoseTester
from numpy.testing.noseclasses import KnownFailure, NumpyTestProgram from numpy.testing.noseclasses import KnownFailure, NumpyTestProgram
...@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester): ...@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester):
:type extra_argv: list :type extra_argv: list
:param extra_argv: List with any extra arguments to pass to nosetests. :param extra_argv: List with any extra arguments to pass to nosetests.
""" """
#self.package_path = os.path.abspath(self.package_path) # self.package_path = os.path.abspath(self.package_path)
argv = [__file__, self.package_path] argv = [__file__, self.package_path]
argv += ['--verbosity', str(verbose)] argv += ['--verbosity', str(verbose)]
if extra_argv: if extra_argv:
...@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester): ...@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester):
return argv return argv
def _show_system_info(self): def _show_system_info(self):
nose = import_nose()
import theano import theano
print("Theano version %s" % theano.__version__) print("Theano version %s" % theano.__version__)
theano_dir = os.path.dirname(theano.__file__) theano_dir = os.path.dirname(theano.__file__)
...@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester): ...@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester):
Takes the same arguments as `test`. Takes the same arguments as `test`.
""" """
# fail with nice error message if nose is not present
nose = import_nose()
# compile argv # compile argv
argv = self._test_argv(verbose, extra_argv) argv = self._test_argv(verbose, extra_argv)
# numpy way of doing coverage # numpy way of doing coverage
if coverage: if coverage:
argv += ['--cover-package=%s' % self.package_name, '--with-coverage', argv += ['--cover-package=%s' % self.package_name,
'--cover-tests', '--cover-inclusive', '--cover-erase'] '--with-coverage', '--cover-tests',
'--cover-inclusive', '--cover-erase']
# Capture output only if needed # Capture output only if needed
if not capture: if not capture:
...@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester): ...@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester):
:param extra_argv: List with any extra arguments to pass to nosetests. :param extra_argv: List with any extra arguments to pass to nosetests.
:type coverage: bool :type coverage: bool
:param coverage: If True, report coverage of Theano code. Default is False. :param coverage: If True, report coverage of Theano
code. Default is False.
:type capture: bool :type capture: bool
:param capture: If True, capture the standard output of the tests, like :param capture: If True, capture the standard output of the tests, like
...@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester): ...@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester):
def main(modulename): def main(modulename):
debug = False
if 0: if 0:
unittest.main() unittest.main()
elif len(sys.argv) == 2 and sys.argv[1] == "--debug": elif len(sys.argv) == 2 and sys.argv[1] == "--debug":
......
...@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>" ...@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>"
whitelist_flake8 = [ whitelist_flake8 = [
"__init__.py", "__init__.py",
"version.py",
"tests/test_gradient.py", "tests/test_gradient.py",
"tests/test_config.py", "tests/test_config.py",
"tests/diverse_tests.py", "tests/diverse_tests.py",
...@@ -31,37 +30,20 @@ whitelist_flake8 = [ ...@@ -31,37 +30,20 @@ whitelist_flake8 = [
"tests/test_record.py", "tests/test_record.py",
"tests/__init__.py", "tests/__init__.py",
"tests/test_updates.py", "tests/test_updates.py",
"tests/main.py",
"tests/test_pickle_unpickle_theano_fn.py", "tests/test_pickle_unpickle_theano_fn.py",
"tests/test_determinism.py", "tests/test_determinism.py",
"tests/record.py", "tests/record.py",
"tests/test_printing.py",
"tests/test_tutorial.py", "tests/test_tutorial.py",
"tests/disturb_mem.py",
"tests/unittest_tools.py", "tests/unittest_tools.py",
"compile/ops.py",
"compile/debugmode.py",
"compile/function.py",
"compile/pfunc.py",
"compile/mode.py",
"compile/profilemode.py",
"compile/builders.py",
"compile/__init__.py", "compile/__init__.py",
"compile/profiling.py", "compile/profiling.py",
"compile/function_module.py",
"compile/sharedvalue.py",
"compile/monitormode.py",
"compile/io.py",
"compile/module.py",
"compile/tests/test_builders.py", "compile/tests/test_builders.py",
"compile/tests/test_misc.py", "compile/tests/test_misc.py",
"compile/tests/test_monitormode.py", "compile/tests/test_monitormode.py",
"compile/tests/test_function_module.py", "compile/tests/test_function_module.py",
"compile/tests/test_inplace_opt_for_value.py",
"compile/tests/test_shared.py", "compile/tests/test_shared.py",
"compile/tests/test_ops.py", "compile/tests/test_ops.py",
"compile/tests/test_pfunc.py", "compile/tests/test_pfunc.py",
"compile/tests/test_module.py",
"compile/tests/test_debugmode.py", "compile/tests/test_debugmode.py",
"compile/tests/test_profiling.py", "compile/tests/test_profiling.py",
"typed_list/type.py", "typed_list/type.py",
...@@ -94,16 +76,13 @@ whitelist_flake8 = [ ...@@ -94,16 +76,13 @@ whitelist_flake8 = [
"tensor/io.py", "tensor/io.py",
"tensor/elemwise_cgen.py", "tensor/elemwise_cgen.py",
"tensor/raw_random.py", "tensor/raw_random.py",
"tensor/randomstreams.py",
"tensor/blas_scipy.py", "tensor/blas_scipy.py",
"tensor/basic.py", "tensor/basic.py",
"tensor/tests/test_subtensor.py", "tensor/tests/test_subtensor.py",
"tensor/tests/test_utils.py", "tensor/tests/test_utils.py",
"tensor/tests/test_nlinalg.py", "tensor/tests/test_nlinalg.py",
"tensor/tests/test_randomstreams.py",
"tensor/tests/test_shared_randomstreams.py", "tensor/tests/test_shared_randomstreams.py",
"tensor/tests/test_misc.py", "tensor/tests/test_misc.py",
"tensor/tests/test_naacl09.py",
"tensor/tests/mlp_test.py", "tensor/tests/mlp_test.py",
"tensor/tests/test_opt_uncanonicalize.py", "tensor/tests/test_opt_uncanonicalize.py",
"tensor/tests/test_opt.py", "tensor/tests/test_opt.py",
...@@ -155,7 +134,6 @@ whitelist_flake8 = [ ...@@ -155,7 +134,6 @@ whitelist_flake8 = [
"sandbox/test_theano_object.py", "sandbox/test_theano_object.py",
"sandbox/test_scan.py", "sandbox/test_scan.py",
"sandbox/rng_mrg.py", "sandbox/rng_mrg.py",
"sandbox/downsample.py",
"sandbox/solve.py", "sandbox/solve.py",
"sandbox/theano_object.py", "sandbox/theano_object.py",
"sandbox/scan.py", "sandbox/scan.py",
...@@ -190,7 +168,6 @@ whitelist_flake8 = [ ...@@ -190,7 +168,6 @@ whitelist_flake8 = [
"sandbox/cuda/nvcc_compiler.py", "sandbox/cuda/nvcc_compiler.py",
"sandbox/cuda/neighbours.py", "sandbox/cuda/neighbours.py",
"sandbox/cuda/tests/walltime.py", "sandbox/cuda/tests/walltime.py",
"sandbox/cuda/tests/test_fftconv.py",
"sandbox/cuda/tests/test_gradient.py", "sandbox/cuda/tests/test_gradient.py",
"sandbox/cuda/tests/test_neighbours.py", "sandbox/cuda/tests/test_neighbours.py",
"sandbox/cuda/tests/test_conv_cuda_ndarray.py", "sandbox/cuda/tests/test_conv_cuda_ndarray.py",
...@@ -218,7 +195,6 @@ whitelist_flake8 = [ ...@@ -218,7 +195,6 @@ whitelist_flake8 = [
"sandbox/scan_module/tests/test_utils.py", "sandbox/scan_module/tests/test_utils.py",
"sandbox/scan_module/tests/test_scan.py", "sandbox/scan_module/tests/test_scan.py",
"sandbox/linalg/ops.py", "sandbox/linalg/ops.py",
"sandbox/linalg/kron.py",
"sandbox/linalg/__init__.py", "sandbox/linalg/__init__.py",
"sandbox/linalg/tests/test_linalg.py", "sandbox/linalg/tests/test_linalg.py",
"sandbox/gpuarray/comp.py", "sandbox/gpuarray/comp.py",
...@@ -288,24 +264,12 @@ whitelist_flake8 = [ ...@@ -288,24 +264,12 @@ whitelist_flake8 = [
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/destroyhandler.py", "gof/destroyhandler.py",
"gof/vm.py",
"gof/cutils.py",
"gof/compiledir.py",
"gof/unify.py", "gof/unify.py",
"gof/lazylinker_c.py",
"gof/optdb.py",
"gof/utils.py",
"gof/graph.py", "gof/graph.py",
"gof/callcache.py",
"gof/python25.py",
"gof/type.py",
"gof/__init__.py", "gof/__init__.py",
"gof/cc.py", "gof/cc.py",
"gof/opt.py", "gof/opt.py",
"gof/compilelock.py",
"gof/link.py", "gof/link.py",
"gof/sched.py",
"gof/toolbox.py",
"gof/fg.py", "gof/fg.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py", "gof/cmodule.py",
...@@ -322,9 +286,6 @@ whitelist_flake8 = [ ...@@ -322,9 +286,6 @@ whitelist_flake8 = [
"gof/tests/test_cc.py", "gof/tests/test_cc.py",
"gof/tests/test_compute_test_value.py", "gof/tests/test_compute_test_value.py",
"gof/sandbox/equilibrium.py", "gof/sandbox/equilibrium.py",
"sandbox/cuda/opt_util.py",
"gof/tests/test_utils.py",
"tensor/tests/_test_mpi_roundtrip.py",
] ]
......
try: try:
from theano.generated_version import * from theano.generated_version import * # noqa
except ImportError: except ImportError:
short_version = 'unknown' short_version = 'unknown'
version = 'unknown' version = 'unknown'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论