提交 a812ee64 authored 作者: Frederic Bastien's avatar Frederic Bastien

merge

......@@ -16,6 +16,9 @@ from mode import *
from io import *
# used by function and module as the default compilation mode
mode_default = 'FAST_COMPILE'
def infer_reuse_pattern(env, outputs_to_disown):
"""
Given an env and a list of results, returns the list of all
......@@ -451,7 +454,7 @@ class FunctionMaker(object):
raise TypeError("Unknown output type: %s (%s)", type(output), output)
def __init__(self, inputs, outputs,
mode = 'FAST_COMPILE', accept_inplace = False, function_builder = Function):
mode = mode_default, accept_inplace = False, function_builder = Function):
"""
:type inputs: a list of SymbolicInput instances
......@@ -641,7 +644,7 @@ def register_checker(checker):
def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
def function(inputs, outputs, mode=mode_default, accept_inplace = False):
"""
Return a function calculating the outputs from the inputs.
......
......@@ -125,7 +125,7 @@ class Component(object):
"""
raise NotImplementedError
def make_no_init(self, mode='FAST_COMPILE'):
def make_no_init(self, mode=F.mode_default):
"""
Allocates the necessary containers using allocate() and uses
build() with the provided mode to make an instance which will
......@@ -145,7 +145,7 @@ class Component(object):
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
mode = kwargs.pop('mode', 'FAST_COMPILE')
mode = kwargs.pop('mode', F.mode_default)
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs)
......@@ -956,7 +956,7 @@ class Module(ComponentDict):
"""
self.make_mi(args,kwargs)
mode = kwargs.pop('mode', 'FAST_COMPILE')
mode = kwargs.pop('mode', F.mode_default)
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs)
......
......@@ -236,6 +236,7 @@ class PPrinter:
return "\n".join(s[1] for s in strings)
def __call__(self, *args):
print sys.stderr, "WARNING: PPrinter bug. Is theano ticket #249 fixed yet?"
if len(args) == 1:
return self.process(*args)
elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)):
......
......@@ -11,6 +11,7 @@ broken.
"""
import time, copy, sys
from StringIO import StringIO
from .. import gof
......@@ -21,6 +22,7 @@ from ..gof.cc import OpWiseCLinker, CLinker
from ..compile.mode import Mode
import numpy
from ..compile.function_module import (convert_function_input,
FunctionMaker,
predefined_modes,
......@@ -31,6 +33,24 @@ from ..compile.function_module import (convert_function_input,
SymbolicOutput,
Supervisor)
def debugprint(a, prefix='', depth=-1, done=None, file=sys.stdout):
if depth==0:
return
done = set() if done is None else done
if hasattr(a, 'op'):
print >> file, prefix, a.op, id(a)
if id(a) not in done:
done.add(id(a))
for i in a.inputs:
if i.owner:
debugprint(i.owner, prefix+' ', depth=depth-1, done=done, file=file)
else:
print >> file, prefix+' ', i, id(i)
else:
print >> file, prefix+' ', a, id(a)
return file
class ResultEquivalenceTracker(object):
def __init__(self):
self.env = None
......@@ -43,6 +63,7 @@ class ResultEquivalenceTracker(object):
self.env = env
self.all_results_ever = []
self.reasons = {}
self.snapshots = {}
def on_detach(self, env):
assert env is self.env
......@@ -70,15 +91,22 @@ class ResultEquivalenceTracker(object):
self.equiv[r] = set([r])
self.all_results_ever.append(r)
self.reasons.setdefault(r, [])
self.snapshots.setdefault(r, [])
for r in node.inputs:
self.reasons.setdefault(r, [])
self.snapshots.setdefault(r, [])
def on_change_input(self, env, node, i, r, new_r, reason=None):
#print 'CHANGE by', reason, 'to use', new_r, type(new_r)
self.reasons.setdefault(new_r, [])
self.snapshots.setdefault(new_r, [])
if (reason, r) not in self.reasons[new_r]:
self.reasons[new_r].append((reason, r))
self.snapshots[new_r].append((
reason,
debugprint(r.owner, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r.owner,prefix=' ', depth=6, file=StringIO()).getvalue()))
self.reasons[r].append(('replaced by', new_r))
if r in self.equiv:
......@@ -252,6 +280,12 @@ class OptCheckLinker(OpWiseCLinker):
print " Value Type:", type(r_vals[r])
print " Value: ", r_vals[r]
print " Reason: ", [(str(reason), id(old_r)) for reason, old_r in env.equivalence_tracker.reasons[r]]
print " Snapshots:"
for s in env.equivalence_tracker.snapshots[r]:
print " BEFORE"
print s[1]
print " AFTER"
print s[2]
print ""
raise Exception("OptCheckFailure")
......
......@@ -801,6 +801,7 @@ def local_greedy_distributor(node):
increase numerical stability, e.g. when x and/or y tend to 0 in
example 1.
"""
out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum:
......@@ -816,7 +817,6 @@ def local_greedy_distributor(node):
num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate, num, denum)
change |= _change
if change:
new_num.append(candidate)
for candidate in list(denum):
......@@ -825,7 +825,6 @@ def local_greedy_distributor(node):
denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, denum, num)
change |= _change
if change:
new_denum.append(candidate)
if not change:
......
......@@ -13,6 +13,9 @@ from theano import pprint
import numpy
#import scalar_opt
from theano.sandbox.debugmode import OptCheck
from theano import function
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64')('x')
......@@ -120,6 +123,25 @@ class test_greedy_distribute(unittest.TestCase):
gof.TopoOptimizer(gof.LocalOptGroup(local_greedy_distributor), order = 'out_to_in').optimize(g)
##print pprint(g.outputs[0])
def test_kording_bug(self):
x, y = vectors('xy')
eps = scalar('eps')
s = scalar('s')
#r = theano.tensor.mul(theano.tensor.fill(x, 2.*a), x/a , (y+z) , a)
#r = theano.tensor.mul((x/a+y) , a, z)
r = mul(
s - 1
, eps + x/s
, eps + y/s
, s)
f = function([s, eps, x,y], r**2, mode=OptCheck())
r0 = f(4,1.e-6, [1.5,2], [2.3,3.1])
r1 = f(4,1.e-6, [1.5,2], [2.3,3.1])
r2 = f(4,1.e-6, [1.5,2], [2.3,3.1])
class test_canonize(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论