提交 de5e06e7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1579 from nouiz/faster_opt

Faster opt
...@@ -36,3 +36,7 @@ Reference ...@@ -36,3 +36,7 @@ Reference
***TODO*** ***TODO***
.. note:: FunctionGraph(inputs, outputs) clone the inputs by
default. To don't have this behavior, add the parameter
clone=False. This is needed as we don't want cached constant
in fgraph.
...@@ -631,9 +631,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -631,9 +631,8 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.variable for spec in output_specs] + updates orig_outputs = [spec.variable for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph = gof.fg.FunctionGraph(inputs, outputs, fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
# DestroyHandler may not be needed yet, as there is usually no # DestroyHandler may not be needed yet, as there is usually no
# inplace operation in the graph at this stage. DestroyHandler # inplace operation in the graph at this stage. DestroyHandler
# will be installed by an optimization after canonicalization, # will be installed by an optimization after canonicalization,
...@@ -658,7 +657,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -658,7 +657,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
break break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, inputs) fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or (hasattr(fgraph, 'destroyers') if not (spec.mutable or (hasattr(fgraph, 'destroyers')
and fgraph.destroyers(input))))) and fgraph.destroyers(input)))))
...@@ -1595,7 +1594,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1595,7 +1594,7 @@ class _Linker(gof.link.LocalLinker):
# directly from PureOp) # directly from PureOp)
if not isinstance(node.op, gof.op.Op): if not isinstance(node.op, gof.op.Op):
raise utils.MethodNotDefined() raise utils.MethodNotDefined()
e = FunctionGraph(*graph.clone(node.inputs, node.outputs)) e = FunctionGraph(node.inputs, node.outputs)
# The toposort isn't a stochastic order as it contain only one node. # The toposort isn't a stochastic order as it contain only one node.
e.toposort = lambda: list(e.apply_nodes) e.toposort = lambda: list(e.apply_nodes)
# Specifically... e.nodes is a set, but of only 1 element # Specifically... e.nodes is a set, but of only 1 element
......
...@@ -129,8 +129,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -129,8 +129,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.variable for spec in output_specs] + updates orig_outputs = [spec.variable for spec in output_specs] + updates
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
...@@ -143,7 +142,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -143,7 +142,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature( fgraph.attach_feature(
Supervisor(input Supervisor(input
for spec, input in zip(input_specs, inputs) for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input))))) fgraph.destroyers(input)))))
...@@ -1306,12 +1305,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1306,12 +1305,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
else: else:
Maker = getattr(mode, 'function_maker', FunctionMaker) Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, fn = Maker(inputs,
outputs, outputs,
mode, mode,
accept_inplace=accept_inplace, accept_inplace=accept_inplace,
profile=profile, profile=profile,
on_unused_input=on_unused_input).create( on_unused_input=on_unused_input).create(
defaults) defaults)
t2 = time.time() t2 = time.time()
if profile: if profile:
......
...@@ -43,7 +43,7 @@ from theano.gof.compiledir import \ ...@@ -43,7 +43,7 @@ from theano.gof.compiledir import \
local_bitwidth, python_int_bitwidth local_bitwidth, python_int_bitwidth
from theano.gof.fg import \ from theano.gof.fg import \
InconsistencyError, MissingInputError, FunctionGraph CachedConstantError, InconsistencyError, MissingInputError, FunctionGraph
from theano.gof.destroyhandler import \ from theano.gof.destroyhandler import \
DestroyHandler DestroyHandler
......
...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception ...@@ -5,6 +5,7 @@ Contains the FunctionGraph class and exception
types that it can raise types that it can raise
""" """
import sys import sys
import time
import theano import theano
from theano.gof import graph from theano.gof import graph
...@@ -19,6 +20,14 @@ from theano.gof.python25 import OrderedDict ...@@ -19,6 +20,14 @@ from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -65,7 +74,7 @@ class FunctionGraph(utils.object2): ...@@ -65,7 +74,7 @@ class FunctionGraph(utils.object2):
""" """
def __init__(self, inputs, outputs, features=None): def __init__(self, inputs, outputs, features=None, clone=True):
""" """
Create an FunctionGraph which operates on the subgraph bound by the inputs and Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets. outputs sets.
...@@ -76,7 +85,14 @@ class FunctionGraph(utils.object2): ...@@ -76,7 +85,14 @@ class FunctionGraph(utils.object2):
#TODO: document what variables are[not] set in the FunctionGraph when a feature #TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph? is added via the constructor. How constructed is the FunctionGraph?
:param clone: If true, we will clone the graph. This is
usefull to remove the constant cache problem.
""" """
if clone:
inputs, outputs = graph.clone(inputs, outputs)
self.execute_callbacks_time = 0
if features is None: if features is None:
features = [] features = []
...@@ -119,6 +135,11 @@ class FunctionGraph(utils.object2): ...@@ -119,6 +135,11 @@ class FunctionGraph(utils.object2):
### Setup a Variable ### ### Setup a Variable ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if getattr(r, 'cached', False):
raise CachedConstantError(
"You manually constructed a FunctionGraph, but you passed it a"
" graph that have cached constant. This should happen."
" Clone the graph before building the FunctionGraph")
if (hasattr(r, 'fgraph') and if (hasattr(r, 'fgraph') and
r.fgraph is not None and r.fgraph is not None and
r.fgraph is not self): r.fgraph is not self):
...@@ -223,10 +244,8 @@ class FunctionGraph(utils.object2): ...@@ -223,10 +244,8 @@ class FunctionGraph(utils.object2):
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
r_owner_done = set(self.apply_nodes)
for apply_node in [r.owner for r in variables if r.owner is not None]: for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done: if apply_node not in self.apply_nodes:
r_owner_done.add(apply_node)
self.__import__(apply_node, reason=reason) self.__import__(apply_node, reason=reason)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
...@@ -521,6 +540,7 @@ class FunctionGraph(utils.object2): ...@@ -521,6 +540,7 @@ class FunctionGraph(utils.object2):
getattr(feature, name)(*args) getattr(feature, name)(*args)
for each feature which has a method called after name. for each feature which has a method called after name.
""" """
t0 = time.time()
for feature in self._features: for feature in self._features:
try: try:
fn = getattr(feature, name) fn = getattr(feature, name)
...@@ -531,6 +551,7 @@ class FunctionGraph(utils.object2): ...@@ -531,6 +551,7 @@ class FunctionGraph(utils.object2):
continue continue
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """WRITEME
......
...@@ -593,7 +593,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -593,7 +593,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
#logger.debug('Compiling node %i of graph' % node_idx) #logger.debug('Compiling node %i of graph' % node_idx)
if self._op_use_c_code: if self._op_use_c_code:
try: try:
e = FunctionGraph(*graph.clone(node.inputs, node.outputs)) e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs) for (new_o, old_o) in zip(e.outputs, node.outputs)
......
...@@ -76,7 +76,13 @@ class Optimizer(object): ...@@ -76,7 +76,13 @@ class Optimizer(object):
opt.apply(fgraph) opt.apply(fgraph)
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
return self.apply(fgraph, *args, **kwargs) try:
orig = theano.tensor.basic.constant.enable
theano.tensor.basic.constant.enable = False
ret = self.apply(fgraph, *args, **kwargs)
finally:
theano.tensor.basic.constant.enable = orig
return ret
def __call__(self, fgraph): def __call__(self, fgraph):
"""WRITEME """WRITEME
...@@ -165,6 +171,10 @@ class SeqOptimizer(Optimizer, list): ...@@ -165,6 +171,10 @@ class SeqOptimizer(Optimizer, list):
l = [] l = []
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
sub_validate_time = [validate_before]
else:
sub_validate_time = []
callback_before = fgraph.execute_callbacks_time
nb_node_before = len(fgraph.apply_nodes) nb_node_before = len(fgraph.apply_nodes)
sub_profs = [] sub_profs = []
for optimizer in self: for optimizer in self:
...@@ -173,6 +183,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -173,6 +183,8 @@ class SeqOptimizer(Optimizer, list):
sub_prof = optimizer.optimize(fgraph) sub_prof = optimizer.optimize(fgraph)
l.append(float(time.time() - t0)) l.append(float(time.time() - t0))
sub_profs.append(sub_prof) sub_profs.append(sub_prof)
if fgraph.profile:
sub_validate_time.append(fgraph.profile.validate_time)
except AssertionError: except AssertionError:
# do not catch Assertion failures # do not catch Assertion failures
raise raise
...@@ -187,8 +199,9 @@ class SeqOptimizer(Optimizer, list): ...@@ -187,8 +199,9 @@ class SeqOptimizer(Optimizer, list):
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
else: else:
validate_time = None validate_time = None
return (self, l, validate_time, nb_node_before, callback_time = fgraph.execute_callbacks_time - callback_before
len(fgraph.apply_nodes), sub_profs) return (self, l, validate_time, callback_time, nb_node_before,
len(fgraph.apply_nodes), sub_profs, sub_validate_time)
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -208,8 +221,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -208,8 +221,8 @@ class SeqOptimizer(Optimizer, list):
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
(opts, prof, validate_time, nb_node_before, (opts, prof, validate_time, callback_time, nb_node_before,
nb_node_after, sub_profs) = prof nb_node_after, sub_profs, sub_validate_time) = prof
blanc = (' ' * level) blanc = (' ' * level)
print >> stream, blanc, "SeqOptimizer", print >> stream, blanc, "SeqOptimizer",
...@@ -222,8 +235,10 @@ class SeqOptimizer(Optimizer, list): ...@@ -222,8 +235,10 @@ class SeqOptimizer(Optimizer, list):
sum(prof), nb_node_before, nb_node_after)) sum(prof), nb_node_before, nb_node_after))
print >> stream, \ print >> stream, \
blanc, " %.3fs for fgraph.validate()" % (validate_time) blanc, " %.3fs for fgraph.validate()" % (validate_time)
print >> stream, \
blanc, " %.3fs for callback" % (callback_time)
if level == 0: if level == 0:
print >> stream, blanc, " time - (name, class, index)" print >> stream, blanc, " time - (name, class, index) - validate time"
ll = [] ll = []
for opt in opts: for opt in opts:
if hasattr(opt, "__name__"): if hasattr(opt, "__name__"):
...@@ -245,7 +260,14 @@ class SeqOptimizer(Optimizer, list): ...@@ -245,7 +260,14 @@ class SeqOptimizer(Optimizer, list):
for (t, opt) in lll[::-1]: for (t, opt) in lll[::-1]:
#if t < 1: #if t < 1:
# continue # continue
print >> stream, blanc, ' %.6fs - %s' % (t, opt) if sub_validate_time:
i = opt[-1]
val_time = sub_validate_time[i + 1] - sub_validate_time[i]
print >> stream, blanc, ' %.6fs - %s - %.3fs' % (
t, opt, val_time)
else:
print >> stream, blanc, ' %.6fs - %s' % (t, opt)
if sub_profs[opt[-1]]: if sub_profs[opt[-1]]:
opts[opt[-1]].print_profile(stream, sub_profs[opt[-1]], opts[opt[-1]].print_profile(stream, sub_profs[opt[-1]],
level=level + 1) level=level + 1)
...@@ -539,6 +561,13 @@ class MergeOptimizer(Optimizer): ...@@ -539,6 +561,13 @@ class MergeOptimizer(Optimizer):
# Constant and non-constant are now applied in the same phase. # Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way. # I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled sched = fgraph.merge_feature.scheduled
nb_fail = 0
t0 = time.time()
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callback_before = fgraph.execute_callbacks_time
nb_merged = 0
nb_constant = 0
while sched: while sched:
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
...@@ -547,17 +576,44 @@ class MergeOptimizer(Optimizer): ...@@ -547,17 +576,44 @@ class MergeOptimizer(Optimizer):
fgraph.replace_all_validate(pairs, 'Merge') fgraph.replace_all_validate(pairs, 'Merge')
except InconsistencyError: except InconsistencyError:
success = False success = False
nb_fail += 1
fgraph.merge_feature.blacklist.append( fgraph.merge_feature.blacklist.append(
(pairs[0][0].owner, pairs[0][1].owner)) (pairs[0][0].owner, pairs[0][1].owner))
if success: if success:
nb_merged += len(pairs)
if isinstance(pairs[0][0], graph.Constant):
nb_constant += 1
#print pairs, pairs[0][0].type
break break
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callback_time = fgraph.execute_callbacks_time - callback_before
else:
validate_time = None
callback_time = None
# clear blacklist # clear blacklist
fgraph.merge_feature.blacklist = [] fgraph.merge_feature.blacklist = []
return (nb_fail, time.time() - t0, validate_time,
callback_time, nb_merged, nb_constant)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
@staticmethod
def print_profile(stream, prof, level=0):
nb_fail, replace_time, validate_time, callback_time, nb_merged, nb_constant = prof
blanc = (' ' * level)
print >> stream, blanc, "MergeOptimizer"
print >> stream, blanc, " nb_fail", nb_fail
print >> stream, blanc, " replace_time", replace_time
print >> stream, blanc, " validate_time", validate_time
print >> stream, blanc, " callback_time", callback_time
print >> stream, blanc, " nb_merged", nb_merged
print >> stream, blanc, " nb_constant", nb_constant
merge_optimizer = MergeOptimizer() merge_optimizer = MergeOptimizer()
...@@ -575,7 +631,9 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -575,7 +631,9 @@ def is_same_graph_with_merge(var1, var2, givens=None):
givens = copied[2] givens = copied[2]
# Create FunctionGraph. # Create FunctionGraph.
inputs = theano.gof.graph.inputs(vars) inputs = theano.gof.graph.inputs(vars)
fgraph = theano.gof.fg.FunctionGraph(inputs, vars) # The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution. # Perform Variable substitution.
for to_replace, replace_by in givens.iteritems(): for to_replace, replace_by in givens.iteritems():
fgraph.replace(to_replace, replace_by) fgraph.replace(to_replace, replace_by)
...@@ -1114,7 +1172,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1114,7 +1172,7 @@ class NavigatorOptimizer(Optimizer):
pass pass
def __init__(self, local_opt, ignore_newtrees='auto', def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None): failure_callback=None):
""" """
:param local_opt: a LocalOptimizer to apply over a FunctionGraph :param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too). (or None is Ok too).
...@@ -1278,7 +1336,11 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1278,7 +1336,11 @@ class TopoOptimizer(NavigatorOptimizer):
def apply(self, fgraph, start_from=None): def apply(self, fgraph, start_from=None):
if start_from is None: if start_from is None:
start_from = fgraph.outputs start_from = fgraph.outputs
callback_before = fgraph.execute_callbacks_time
nb_nodes_start = len(fgraph.apply_nodes)
t0 = time.time()
q = deque(graph.io_toposort(fgraph.inputs, start_from)) q = deque(graph.io_toposort(fgraph.inputs, start_from))
io_t = time.time() - t0
def importer(node): def importer(node):
if node is not current_node: if node is not current_node:
...@@ -1292,19 +1354,40 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -1292,19 +1354,40 @@ class TopoOptimizer(NavigatorOptimizer):
pass pass
u = self.attach_updater(fgraph, importer, pruner) u = self.attach_updater(fgraph, importer, pruner)
nb = 0
try: try:
t0 = time.time()
while q: while q:
if self.order == 'out_to_in': if self.order == 'out_to_in':
node = q.pop() node = q.pop()
else: else:
node = q.popleft() node = q.popleft()
current_node = node current_node = node
self.process_node(fgraph, node) nb += self.process_node(fgraph, node)
loop_t = time.time() - t0
except Exception: except Exception:
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
raise raise
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
callback_time = fgraph.execute_callbacks_time - callback_before
nb_nodes_end = len(fgraph.apply_nodes)
return (nb, nb_nodes_start, nb_nodes_end,
io_t, loop_t, callback_time)
@staticmethod
def print_profile(stream, prof, level=0):
(nb, nb_nodes_start, nb_nodes_end,
io_t, loop_t, callback_time) = prof
blanc = (' ' * level)
print >> stream, blanc, "TopoOptimizer"
print >> stream, blanc, " nb_node (start, end, changed)", (
nb_nodes_start, nb_nodes_end, nb)
print >> stream, blanc, " init io_toposort", io_t
print >> stream, blanc, " loop time", loop_t
print >> stream, blanc, " callback_time", callback_time
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
'<TopoOptimizer instance>') '<TopoOptimizer instance>')
......
import unittest
import theano
from theano.gof import CachedConstantError, FunctionGraph
class TFunctionGraph(unittest.TestCase):
def test_constant_cache_error(self):
v = theano.tensor.constant(1)
assert v.cached
self.assertRaises(CachedConstantError, FunctionGraph, [], [v + 1],
clone=False)
def test_clone(self):
v = theano.tensor.constant(1)
assert v.cached
FunctionGraph([], [v + 1])
...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply ...@@ -3,7 +3,7 @@ from theano.gof.graph import Variable, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.fg import FunctionGraph as Env, InconsistencyError from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import * from theano.gof.toolbox import *
...@@ -61,14 +61,13 @@ def inputs(): ...@@ -61,14 +61,13 @@ def inputs():
return x, y, z return x, y, z
class TestNodeFinder: class TestNodeFinder:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder()) g.attach_feature(NodeFinder())
assert hasattr(g, 'get_nodes') assert hasattr(g, 'get_nodes')
......
...@@ -220,7 +220,7 @@ class GpuArrayVariable(_operators, Variable): ...@@ -220,7 +220,7 @@ class GpuArrayVariable(_operators, Variable):
GpuArrayType.Variable = GpuArrayVariable GpuArrayType.Variable = GpuArrayVariable
class GpuArraySignature(tensor.basic.TensorConstantSignature): class GpuArraySignature(tensor.TensorConstantSignature):
pass # might do something better if we can run the sum on the pass # might do something better if we can run the sum on the
# GPU, but for now this will suffice. # GPU, but for now this will suffice.
......
...@@ -2928,7 +2928,10 @@ class Composite(ScalarOp): ...@@ -2928,7 +2928,10 @@ class Composite(ScalarOp):
self.name = rval self.name = rval
def init_fgraph(self): def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs)) #The clone done by FunctionGraph is needed as we don't want
#the fgraph to be set to the variable as we need to pickle
#them for the cache of c module to work.
fgraph = FunctionGraph(self.inputs, self.outputs)
gof.MergeOptimizer().optimize(fgraph) gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
......
...@@ -182,8 +182,8 @@ class Scan(PureOp): ...@@ -182,8 +182,8 @@ class Scan(PureOp):
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
if not self.info['gpu']: if not self.info['gpu']:
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
else: else:
......
...@@ -173,7 +173,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -173,7 +173,7 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3 max_iterations = 2 * len(local_fgraph.toposort()) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
...@@ -347,7 +347,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -347,7 +347,7 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3 max_iterations = 2 * len(local_fgraph.toposort()) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
......
...@@ -6,6 +6,9 @@ import warnings ...@@ -6,6 +6,9 @@ import warnings
from theano.tensor.basic import * from theano.tensor.basic import *
from theano.tensor.subtensor import * from theano.tensor.subtensor import *
from theano.tensor.type_other import * from theano.tensor.type_other import *
from theano.tensor.var import (
AsTensorError, _tensor_py_operators, TensorVariable,
TensorConstantSignature, TensorConstant)
from theano.tensor import opt from theano.tensor import opt
from theano.tensor import opt_uncanonicalize from theano.tensor import opt_uncanonicalize
......
...@@ -5,7 +5,6 @@ __docformat__ = "restructuredtext en" ...@@ -5,7 +5,6 @@ __docformat__ = "restructuredtext en"
import sys import sys
import warnings import warnings
from itertools import izip from itertools import izip
from textwrap import dedent
import numpy import numpy
from copy import copy as python_copy from copy import copy as python_copy
...@@ -17,13 +16,12 @@ from theano.gof import Apply, Constant, Op, Variable ...@@ -17,13 +16,12 @@ from theano.gof import Apply, Constant, Op, Variable
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstantSignature,
TensorConstant, TensorConstant,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all, maxsize from theano.gof.python25 import partial, any, all
from theano.gof.utils import hashtype, MethodNotDefined from theano.gof.utils import hashtype
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
...@@ -400,10 +398,31 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -400,10 +398,31 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
raise TypeError("Could not convert %s to TensorType" % x, type(x)) raise TypeError("Could not convert %s to TensorType" % x, type(x))
constant_cache = {}
def constant(x, name=None, ndim=None, dtype=None): def constant(x, name=None, ndim=None, dtype=None):
return constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
dtype=dtype) dtype=dtype)
#We create a small cache of frequently used constant.
#This speed up the Merge optimization for big graph.
#We want to cache all scalar to don't merge as frequently constants.
#But we don't want to cache too much stuff
#So we cache integer with dtype [u]int and float where the value is between -10 and 10
#We want to cache all broadcast pattern for scalar.
if not constant.enable:
return ret
sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and
ret.data <= 10 and ret.data >= -10 and
(ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
(ret.dtype in float_dtypes and int(ret.data) == ret.data))):
constant_cache[sig] = ret
# This is needed to raise a good error to the user.
ret.cached = True
return constant_cache.get(sig, ret)
constant.enable = True
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
......
...@@ -147,7 +147,7 @@ import theano.scalar ...@@ -147,7 +147,7 @@ import theano.scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift, in2out from theano.tensor.opt import in2out, local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
...@@ -896,9 +896,22 @@ class Gemm(GemmRelated): ...@@ -896,9 +896,22 @@ class Gemm(GemmRelated):
"Wrong number of inputs for %s (expected 5, got %s)" % "Wrong number of inputs for %s (expected 5, got %s)" %
(self, len(inputs))) (self, len(inputs)))
z, a, x, y, b = inputs z, a, x, y, b = inputs
# For the consistency check we don't want z to be a cached constant.
if getattr(z, 'cached', False):
z = copy.copy(z)
zr, xr, yr = [set(view_roots(i)) for i in z, x, y] zr, xr, yr = [set(view_roots(i)) for i in z, x, y]
# TODO: justify / delete # We want the gemm to be inplace. When this op is inplace, it
# declare to be inplace only on z. So to make it safe, we
# raise an error if z can be a view on x or y.
# I don't know if Theano currently can support that case. As
# this case don't happen in our code, I won't spent time
# investigating this. So the assert is for safety. I also
# think there is another mechanism that would prevent this,
# but I don't what to modify old code and have chance to break
# something.
if zr.intersection(xr): if zr.intersection(xr):
raise InconsistencyError(Gemm.E_z_uniq, (z, x)) raise InconsistencyError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr): if zr.intersection(yr):
...@@ -1758,8 +1771,8 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run') ...@@ -1758,8 +1771,8 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# free-for-all that makes the graph crazy. # free-for-all that makes the graph crazy.
blas_optdb.register('local_dot_to_dot22', blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5), in2out(local_dot_to_dot22),
0, 'fast_run') 0, 'fast_run')
blas_optdb.register('gemm_optimizer', blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
...@@ -1983,8 +1996,8 @@ def local_dot22_to_dot22scalar(node): ...@@ -1983,8 +1996,8 @@ def local_dot22_to_dot22scalar(node):
#must happen after gemm as the gemm optimizer don't understant #must happen after gemm as the gemm optimizer don't understant
#dot22scalar and gemm give more speed up then dot22scalar #dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar', blas_optdb.register('local_dot22_to_dot22scalar',
EquilibriumOptimizer([local_dot22_to_dot22scalar], max_use_ratio=5), in2out(local_dot22_to_dot22scalar),
11, 'fast_run') 11, 'fast_run')
#from opt import register_specialize, register_canonicalize #from opt import register_specialize, register_canonicalize
......
from theano import config from theano import config
from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
...@@ -609,21 +609,14 @@ def make_c_gemv_destructive(node): ...@@ -609,21 +609,14 @@ def make_c_gemv_destructive(node):
####### ####### ####### ####### ####### #######
blas_optdb.register('use_c_blas', blas_optdb.register('use_c_blas',
EquilibriumOptimizer([ in2out(use_c_ger, use_c_gemv),
use_c_ger, 20, 'fast_run', 'c_blas')
use_c_gemv,
],
max_use_ratio=5),
20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB' #print 'BLAS_OPTDB'
#print blas_optdb #print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive', optdb.register('c_blas_destructive',
EquilibriumOptimizer([ in2out(make_c_ger_destructive,
make_c_ger_destructive, make_c_gemv_destructive,
make_c_gemv_destructive, name="c_blas_destructive"),
], 70.0, 'fast_run', 'inplace', 'c_blas')
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5),
70.0, 'fast_run', 'inplace', 'c_blas')
...@@ -749,10 +749,8 @@ class Elemwise(Op): ...@@ -749,10 +749,8 @@ class Elemwise(Op):
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent TensorType of size 1 and proper number of # an equivalent TensorType of size 1 and proper number of
# dimensions # dimensions
res = TensorConstant(TensorType(dtype=r.type.dtype, res = theano.tensor.constant(numpy.asarray(r.data), dtype=r.type.dtype)
broadcastable=()), return DimShuffle((), ['x'] * nd, inplace=False)(res)
numpy.asarray(r.data)) # .reshape(b)
return DimShuffle((), ['x'] * nd, inplace=True)(res)
new_r = Elemwise(node.op, {})( new_r = Elemwise(node.op, {})(
*[transform(ipt) for ipt in node.inputs]) *[transform(ipt) for ipt in node.inputs])
return new_r return new_r
......
...@@ -2119,7 +2119,7 @@ def local_IncSubtensor_serialize(node): ...@@ -2119,7 +2119,7 @@ def local_IncSubtensor_serialize(node):
# We register it in a TopoOptimizer inside the canonizer EQ optimizer. # We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# Otherwise in some cases it was making the EQ optimizer use 45. In # Otherwise in some cases it was making the EQ optimizer use 45. In
# the TopoOptimizer, the EQ only use 6 passes. # the TopoOptimizer, the EQ only use 5 passes.
compile.optdb.register('pre_local_IncSubtensor_serialize', compile.optdb.register('pre_local_IncSubtensor_serialize',
in2out(local_IncSubtensor_serialize), in2out(local_IncSubtensor_serialize),
#Just before canonizer #Just before canonizer
...@@ -2136,13 +2136,13 @@ def local_inplace_setsubtensor(node): ...@@ -2136,13 +2136,13 @@ def local_inplace_setsubtensor(node):
""" """
if isinstance(node.op, IncSubtensor) and not node.op.inplace: if isinstance(node.op, IncSubtensor) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
node.op.idx_list, inplace=True, node.op.idx_list, inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc, set_instead_of_inc=node.op.set_instead_of_inc,
destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased) destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
compile.optdb.register('inplace_setsubtensor', compile.optdb.register('local_inplace_setsubtensor',
TopoOptimizer(local_inplace_setsubtensor, TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG 'fast_run', 'inplace') # DEBUG
...@@ -3711,17 +3711,16 @@ def local_add_specialize(node): ...@@ -3711,17 +3711,16 @@ def local_add_specialize(node):
continue continue
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
if len(new_inputs) == 0: if len(new_inputs) == 0:
#we got rid of the entire expression! #we got rid of the entire expression!
ndim = node.outputs[0].type.ndim ndim = node.outputs[0].type.ndim
return fill_chain( #Reuse call to constant for cache()
T.TensorConstant( cst = T.constant(numpy.zeros((1,) * ndim, dtype=dtype))
T.TensorType( assert cst.type.broadcastable == (True,) * ndim
dtype=dtype, return fill_chain(cst)
broadcastable=[True] * ndim),
numpy.zeros((1,) * ndim, dtype=dtype)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
ret = fill_chain(new_inputs[0]) ret = fill_chain(new_inputs[0])
......
...@@ -22,9 +22,6 @@ Also, we should make the fgraph refuse optimization that break the canonization ...@@ -22,9 +22,6 @@ Also, we should make the fgraph refuse optimization that break the canonization
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
import logging import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
...@@ -35,10 +32,12 @@ from theano.tensor import basic as T ...@@ -35,10 +32,12 @@ from theano.tensor import basic as T
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox from theano.gof import InconsistencyError, toolbox
from theano.tensor.basic import get_scalar_constant_value, NotScalarConstantError from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
class MaxAndArgmaxOptimizer(Optimizer): class MaxAndArgmaxOptimizer(Optimizer):
"""Replace MaxAndArgmax by CAReduce when the argmax is not used """Replace MaxAndArgmax by CAReduce when the argmax is not used
...@@ -56,23 +55,25 @@ class MaxAndArgmaxOptimizer(Optimizer): ...@@ -56,23 +55,25 @@ class MaxAndArgmaxOptimizer(Optimizer):
did_something = False did_something = False
for node in nodelist: for node in nodelist:
if node.op == T._max_and_argmax: if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0: if len(node.outputs[1].clients) == 0:
try: try:
axis=get_scalar_constant_value(node.inputs[1]) axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError: except NotScalarConstantError:
return False return False
new = CAReduce(scal.maximum,axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
try: try:
fgraph.replace_all_validate( fgraph.replace_all_validate(
((node.outputs[0],new),), ((node.outputs[0], new),),
reason = self.__class__.__name__) reason=self.__class__.__name__)
did_something = True did_something = True
break break
except InconsistencyError, e: except InconsistencyError, e:
pass pass
register_uncanonicalize(MaxAndArgmaxOptimizer(),name='MaxAndArgmaxOptimizer') register_uncanonicalize(MaxAndArgmaxOptimizer(),
name='MaxAndArgmaxOptimizer')
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T._shape]) @gof.local_optimizer([T._shape])
...@@ -87,9 +88,12 @@ def local_max_to_min(node): ...@@ -87,9 +88,12 @@ def local_max_to_min(node):
""" """
if node.op == T.neg and node.inputs[0].owner: if node.op == T.neg and node.inputs[0].owner:
max = node.inputs[0] max = node.inputs[0]
if max.owner and isinstance(max.owner.op, CAReduce) and max.owner.op.scalar_op==scal.maximum: if (max.owner and
isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == scal.maximum):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum,max.owner.op.axis)(neg.owner.inputs[0])] return [CAReduce(scal.minimum,
max.owner.op.axis)(neg.owner.inputs[0])]
return False return False
import unittest
import numpy import numpy
import theano import theano
from theano.tensor.utils import (hash_from_ndarray, hash_from_dict, from theano.tensor.utils import (hash_from_ndarray, hash_from_dict,
shape_of_variables) shape_of_variables)
def test_hash_from_ndarray(): def test_hash_from_ndarray():
...@@ -10,18 +12,18 @@ def test_hash_from_ndarray(): ...@@ -10,18 +12,18 @@ def test_hash_from_ndarray():
rng = numpy.random.rand(5, 5) rng = numpy.random.rand(5, 5)
for data in [-2, -1, 0, 1, 2, numpy.zeros((1, 5)), numpy.zeros((1, 6)), for data in [-2, -1, 0, 1, 2, numpy.zeros((1, 5)), numpy.zeros((1, 6)),
# Data buffer empty but different shapes # Data buffer empty but different shapes
numpy.zeros((1, 0)), numpy.zeros((2, 0)), numpy.zeros((1, 0)), numpy.zeros((2, 0)),
# Same data buffer and shapes but different strides # Same data buffer and shapes but different strides
numpy.arange(25).reshape(5, 5), numpy.arange(25).reshape(5, 5),
numpy.arange(25).reshape(5, 5).T, numpy.arange(25).reshape(5, 5).T,
# Same data buffer, shapes and strides but different dtypes # Same data buffer, shapes and strides but different dtypes
numpy.zeros((5, 5), dtype="uint32"), numpy.zeros((5, 5), dtype="uint32"),
numpy.zeros((5, 5), dtype="int32"), numpy.zeros((5, 5), dtype="int32"),
# Test slice # Test slice
rng, rng[1:], rng[:4], rng[1:3], rng[::2], rng[::-1] rng, rng[1:], rng[:4], rng[1:3], rng[::2], rng[::-1]
]: ]:
data = numpy.asarray(data) data = numpy.asarray(data)
hashs.append(hash_from_ndarray(data)) hashs.append(hash_from_ndarray(data))
...@@ -49,22 +51,31 @@ def test_hash_from_dict(): ...@@ -49,22 +51,31 @@ def test_hash_from_dict():
# List are not hashable. So they are transformed into tuple. # List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]}) assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
def test_shape_of_variables_simple():
x = theano.tensor.matrix('x') class Tshape_of_variables(unittest.TestCase):
y = x+x def test_simple(self):
fgraph = theano.FunctionGraph([x], [y]) x = theano.tensor.matrix('x')
assert shape_of_variables(fgraph, {x: (5, 5)}) == {x: (5, 5), y: (5, 5)} y = x+x
fgraph = theano.FunctionGraph([x], [y], clone=False)
x = theano.tensor.matrix('x') shapes = shape_of_variables(fgraph, {x: (5, 5)})
y = theano.tensor.dot(x, x.T) assert shapes == {x: (5, 5), y: (5, 5)}
fgraph = theano.FunctionGraph([x], [y])
shapes = shape_of_variables(fgraph, {x: (5, 1)}) x = theano.tensor.matrix('x')
assert shapes[x] == (5, 1) y = theano.tensor.dot(x, x.T)
assert shapes[y] == (5, 5) fgraph = theano.FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (5, 1)})
def test_shape_of_variables_subtensor(): assert shapes[x] == (5, 1)
x = theano.tensor.matrix('x') assert shapes[y] == (5, 5)
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx]) def test_subtensor(self):
shapes = shape_of_variables(fgraph, {x: (10, 10)}) x = theano.tensor.matrix('x')
assert shapes[subx] == (9, 10) subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx], clone=False)
shapes = shape_of_variables(fgraph, {x: (10, 10)})
assert shapes[subx] == (9, 10)
def test_err(self):
x = theano.tensor.matrix('x')
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx])
self.assertRaises(ValueError, shape_of_variables, fgraph, {x: (10, 10)})
import numpy import numpy
import theano import theano
from theano.compat.python2x import any
from theano.gof.cc import hash_from_code from theano.gof.cc import hash_from_code
...@@ -69,7 +70,7 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -69,7 +70,7 @@ def shape_of_variables(fgraph, input_shapes):
>>> import theano >>> import theano
>>> x = theano.tensor.matrix('x') >>> x = theano.tensor.matrix('x')
>>> y = x[512:]; y.name = 'y' >>> y = x[512:]; y.name = 'y'
>>> fgraph = theano.FunctionGraph([x], [y]) >>> fgraph = theano.FunctionGraph([x], [y], clone=False)
>>> shape_of_variables(fgraph, {x: (1024, 1024)}) >>> shape_of_variables(fgraph, {x: (1024, 1024)})
{y: (512, 1024), x: (1024, 1024)} {y: (512, 1024), x: (1024, 1024)}
""" """
...@@ -85,6 +86,12 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -85,6 +86,12 @@ def shape_of_variables(fgraph, input_shapes):
compute_shapes = theano.function(input_dims, output_dims) compute_shapes = theano.function(input_dims, output_dims)
if any([i not in fgraph.inputs for i in input_shapes.keys()]):
raise ValueError(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clone the graph it receive."
" To have the old behavior, give him this new parameter `clone=False`.")
numeric_input_dims = [dim for inp in fgraph.inputs numeric_input_dims = [dim for inp in fgraph.inputs
for dim in input_shapes[inp]] for dim in input_shapes[inp]]
numeric_output_dims = compute_shapes(*numeric_input_dims) numeric_output_dims = compute_shapes(*numeric_input_dims)
......
import copy
import numpy import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano.compat.python2x import all
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
...@@ -674,5 +675,14 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -674,5 +675,14 @@ class TensorConstant(_tensor_py_operators, Constant):
other = theano.tensor.basic.constant(other) other = theano.tensor.basic.constant(other)
return (isinstance(other, TensorConstant) and return (isinstance(other, TensorConstant) and
self.signature() == other.signature()) self.signature() == other.signature())
def __copy__(self):
# We need to do this to remove the cached attribute
return type(self)(self.type, self.data, self.name)
def __deepcopy__(self, memo):
# We need to do this to remove the cached attribute
return type(self)(copy.deepcopy(self.type, memo),
copy.deepcopy(self.data, memo),
copy.deepcopy(self.name, memo))
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论