提交 c87e9000 authored 作者: Frederic Bastien's avatar Frederic Bastien
...@@ -16,6 +16,9 @@ from mode import * ...@@ -16,6 +16,9 @@ from mode import *
from io import * from io import *
# used by function and module as the default compilation mode
mode_default = 'FAST_COMPILE'
def infer_reuse_pattern(env, outputs_to_disown): def infer_reuse_pattern(env, outputs_to_disown):
""" """
Given an env and a list of results, returns the list of all Given an env and a list of results, returns the list of all
...@@ -451,7 +454,7 @@ class FunctionMaker(object): ...@@ -451,7 +454,7 @@ class FunctionMaker(object):
raise TypeError("Unknown output type: %s (%s)", type(output), output) raise TypeError("Unknown output type: %s (%s)", type(output), output)
def __init__(self, inputs, outputs, def __init__(self, inputs, outputs,
mode = 'FAST_COMPILE', accept_inplace = False, function_builder = Function): mode = mode_default, accept_inplace = False, function_builder = Function):
""" """
:type inputs: a list of SymbolicInput instances :type inputs: a list of SymbolicInput instances
...@@ -641,7 +644,7 @@ def register_checker(checker): ...@@ -641,7 +644,7 @@ def register_checker(checker):
def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): def function(inputs, outputs, mode=mode_default, accept_inplace = False):
""" """
Return a function calculating the outputs from the inputs. Return a function calculating the outputs from the inputs.
......
...@@ -125,7 +125,7 @@ class Component(object): ...@@ -125,7 +125,7 @@ class Component(object):
""" """
raise NotImplementedError raise NotImplementedError
def make_no_init(self, mode='FAST_COMPILE'): def make_no_init(self, mode=F.mode_default):
""" """
Allocates the necessary containers using allocate() and uses Allocates the necessary containers using allocate() and uses
build() with the provided mode to make an instance which will build() with the provided mode to make an instance which will
...@@ -145,7 +145,7 @@ class Component(object): ...@@ -145,7 +145,7 @@ class Component(object):
arguments and the keyword arguments. If 'mode' is in the arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build(). keyword arguments it will be passed to build().
""" """
mode = kwargs.pop('mode', 'FAST_COMPILE') mode = kwargs.pop('mode', F.mode_default)
rval = self.make_no_init(mode) rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'): if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs) rval.initialize(*args, **kwargs)
...@@ -956,7 +956,7 @@ class Module(ComponentDict): ...@@ -956,7 +956,7 @@ class Module(ComponentDict):
""" """
self.make_mi(args,kwargs) self.make_mi(args,kwargs)
mode = kwargs.pop('mode', 'FAST_COMPILE') mode = kwargs.pop('mode', F.mode_default)
rval = self.make_no_init(mode) rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'): if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs) rval.initialize(*args, **kwargs)
......
...@@ -236,6 +236,7 @@ class PPrinter: ...@@ -236,6 +236,7 @@ class PPrinter:
return "\n".join(s[1] for s in strings) return "\n".join(s[1] for s in strings)
def __call__(self, *args): def __call__(self, *args):
print sys.stderr, "WARNING: PPrinter bug. Is theano ticket #249 fixed yet?"
if len(args) == 1: if len(args) == 1:
return self.process(*args) return self.process(*args)
elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)): elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)):
......
...@@ -11,6 +11,7 @@ broken. ...@@ -11,6 +11,7 @@ broken.
""" """
import time, copy, sys import time, copy, sys
from StringIO import StringIO
from .. import gof from .. import gof
...@@ -21,6 +22,7 @@ from ..gof.cc import OpWiseCLinker, CLinker ...@@ -21,6 +22,7 @@ from ..gof.cc import OpWiseCLinker, CLinker
from ..compile.mode import Mode from ..compile.mode import Mode
import numpy import numpy
from ..compile.function_module import (convert_function_input, from ..compile.function_module import (convert_function_input,
FunctionMaker, FunctionMaker,
predefined_modes, predefined_modes,
...@@ -31,6 +33,24 @@ from ..compile.function_module import (convert_function_input, ...@@ -31,6 +33,24 @@ from ..compile.function_module import (convert_function_input,
SymbolicOutput, SymbolicOutput,
Supervisor) Supervisor)
def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
if depth==0:
return
done = set() if done is None else done
if hasattr(a, 'op'):
print >> file, prefix, a.op, id(a)
if id(a) not in done:
done.add(id(a))
for i in a.inputs:
if i.owner:
debugprint(i.owner, prefix+' ', depth=depth-1, done=done, file=file)
else:
print >> file, prefix+' ', i, id(i)
else:
print >> file, prefix+' ', a, id(a)
return file
class ResultEquivalenceTracker(object): class ResultEquivalenceTracker(object):
def __init__(self): def __init__(self):
self.env = None self.env = None
...@@ -43,6 +63,7 @@ class ResultEquivalenceTracker(object): ...@@ -43,6 +63,7 @@ class ResultEquivalenceTracker(object):
self.env = env self.env = env
self.all_results_ever = [] self.all_results_ever = []
self.reasons = {} self.reasons = {}
self.snapshots = {}
def on_detach(self, env): def on_detach(self, env):
assert env is self.env assert env is self.env
...@@ -70,15 +91,22 @@ class ResultEquivalenceTracker(object): ...@@ -70,15 +91,22 @@ class ResultEquivalenceTracker(object):
self.equiv[r] = set([r]) self.equiv[r] = set([r])
self.all_results_ever.append(r) self.all_results_ever.append(r)
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.snapshots.setdefault(r, [])
for r in node.inputs: for r in node.inputs:
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.snapshots.setdefault(r, [])
def on_change_input(self, env, node, i, r, new_r, reason=None): def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r) #print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.reasons.setdefault(new_r, []) self.reasons.setdefault(new_r, [])
self.snapshots.setdefault(new_r, [])
if (reason, r) not in self.reasons[new_r]: if (reason, r) not in self.reasons[new_r]:
self.reasons[new_r].append((reason, r)) self.reasons[new_r].append((reason, r))
self.snapshots[new_r].append((
reason,
debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue()))
self.reasons[r].append(('replaced by', new_r)) self.reasons[r].append(('replaced by', new_r))
if r in self.equiv: if r in self.equiv:
...@@ -252,6 +280,12 @@ class OptCheckLinker(OpWiseCLinker): ...@@ -252,6 +280,12 @@ class OptCheckLinker(OpWiseCLinker):
print " Value Type:", type(r_vals[r]) print " Value Type:", type(r_vals[r])
print " Value: ", r_vals[r] print " Value: ", r_vals[r]
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[r]] print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print "" print ""
raise Exception("OptCheckFailure") raise Exception("OptCheckFailure")
......
...@@ -801,6 +801,7 @@ def local_greedy_distributor(node): ...@@ -801,6 +801,7 @@ def local_greedy_distributor(node):
increase numerical stability, e.g. when x and/or y tend to 0 in increase numerical stability, e.g. when x and/or y tend to 0 in
example 1. example 1.
""" """
out = node.outputs[0] out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out) num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum: if len(num) == 1 and not denum:
...@@ -816,8 +817,7 @@ def local_greedy_distributor(node): ...@@ -816,8 +817,7 @@ def local_greedy_distributor(node):
num.remove(candidate) num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate, num, denum) _change, candidate, num, denum = attempt_distribution(candidate, num, denum)
change |= _change change |= _change
if change: new_num.append(candidate)
new_num.append(candidate)
for candidate in list(denum): for candidate in list(denum):
if candidate not in denum: if candidate not in denum:
...@@ -825,8 +825,7 @@ def local_greedy_distributor(node): ...@@ -825,8 +825,7 @@ def local_greedy_distributor(node):
denum.remove(candidate) denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, denum, num) _change, candidate, denum, num = attempt_distribution(candidate, denum, num)
change |= _change change |= _change
if change: new_denum.append(candidate)
new_denum.append(candidate)
if not change: if not change:
return False return False
......
...@@ -13,6 +13,9 @@ from theano import pprint ...@@ -13,6 +13,9 @@ from theano import pprint
import numpy import numpy
#import scalar_opt #import scalar_opt
from theano.sandbox.debugmode import OptCheck
from theano import function
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64')('x') x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
...@@ -119,7 +122,26 @@ class test_greedy_distribute(unittest.TestCase): ...@@ -119,7 +122,26 @@ class test_greedy_distribute(unittest.TestCase):
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g) gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_greedy_distributor), order = 'out_to_in').optimize(g) gof.TopoOptimizer(gof.LocalOptGroup(local_greedy_distributor), order = 'out_to_in').optimize(g)
##print pprint(g.outputs[0]) ##print pprint(g.outputs[0])
def test_kording_bug(self):
x, y = vectors('xy')
eps = scalar('eps')
s = scalar('s')
#r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a)
#r = theano.tensor.mul((x/a+y) , a, z)
r = mul(
s - 1
, eps + x/s
, eps + y/s
, s)
f = function([s, eps, x,y], r**2, mode=OptCheck())
r0 = f(4,1.e-6, [1.5,2], [2.3,3.1])
r1 = f(4,1.e-6, [1.5,2], [2.3,3.1])
r2 = f(4,1.e-6, [1.5,2], [2.3,3.1])
class test_canonize(unittest.TestCase): class test_canonize(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论