More backporting

上级 3126f874
import numpy import numpy
import os import os
import scipy.sparse as sp import scipy.sparse as sp
from .. import gof from theano import gof
def check_equal(x, y): def check_equal(x, y):
""" """
......
...@@ -11,10 +11,12 @@ from theano.printing import pprint ...@@ -11,10 +11,12 @@ from theano.printing import pprint
import io, sys import io, sys
if sys.version_info[:2] >= (2,5): if sys.version_info[:2] >= (2,5):
from functools import partial
from collections import defaultdict from collections import defaultdict
else:
from theano.gof.python25 import any, all, defaultdict, partial
from itertools import chain from itertools import chain
if sys.version_info[:2] >= (2,5):
from functools import partial
import function_module as F import function_module as F
import mode as get_mode import mode as get_mode
......
...@@ -10,6 +10,8 @@ import md5 ...@@ -10,6 +10,8 @@ import md5
if sys.version_info[:2] >= (2,5): if sys.version_info[:2] >= (2,5):
import hashlib import hashlib
from theano.gof.python25 import any, all
# weave import # weave import
from scipy import weave from scipy import weave
......
...@@ -2,6 +2,8 @@ import sys ...@@ -2,6 +2,8 @@ import sys
if sys.version_info[:2] >= (2,5): if sys.version_info[:2] >= (2,5):
from functools import partial from functools import partial
else:
from theano.gof.python25 import partial
import graph import graph
......
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
from theano import gof from theano import gof
from theano.gof import Op, utils, Variable, Constant, Type, Apply, Env from theano.gof import Op, utils, Variable, Constant, Type, Apply, Env
from theano.gof.python25 import partial, all from theano.gof.python25 import partial, all, any
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
......
...@@ -18,7 +18,7 @@ from theano import gradient ...@@ -18,7 +18,7 @@ from theano import gradient
import elemwise import elemwise
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any from theano.gof.python25 import partial, any, all
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint, Print from theano.printing import pprint, Print
......
"""Define RModule, a Module providing random number streams in Theano graphs.""" """Define RModule, a Module providing random number streams in Theano graphs."""
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import sys import sys
import functools if sys.version_info[:2] >= (2,5):
from functools import partial from functools import partial
else:
from theano.gof.python25 import partial
from collections import deque from collections import deque
import numpy import numpy
from copy import copy from copy import copy
from ...compile import (SymbolicInputKit, SymbolicInput, from theano.compile import (SymbolicInputKit, SymbolicInput,
Module, module, Method, Member, In, Component) Module, module, Method, Member, In, Component)
from ...gof import Container from theano.gof import Container
from ...tensor import raw_random from theano.tensor import raw_random
class KitComponent(Component): class KitComponent(Component):
""" """
......
...@@ -8,7 +8,7 @@ from theano import scalar ...@@ -8,7 +8,7 @@ from theano import scalar
from theano.scalar import Scalar from theano.scalar import Scalar
from theano import printing from theano import printing
from theano.printing import pprint from theano.printing import pprint
from theano.gof.python25 import all from theano.gof.python25 import all, any
from copy import copy, deepcopy from copy import copy, deepcopy
......
...@@ -2,15 +2,15 @@ ...@@ -2,15 +2,15 @@
## Not all of those ops have been thoroughly tested. ## Not all of those ops have been thoroughly tested.
#from theano import tensor, scalar #from theano import tensor, scalar
from .. import gof from theano import gof
from .. import scalar from theano import scalar
from .. import printing from theano import printing
from ..printing import pprint from theano.printing import pprint
import basic as tensor import basic as tensor
import elemwise import elemwise
import numpy import numpy
import opt import opt
from ..compile import optdb from theano.compile import optdb
############ ############
# #
...@@ -404,8 +404,19 @@ def local_softmax_with_bias(node): ...@@ -404,8 +404,19 @@ def local_softmax_with_bias(node):
assert non_vectors #not empty assert non_vectors #not empty
if vectors: if vectors:
#we're in business... #we're in business...
vector_sum = tensor.add(*vectors) if len(vectors)>1 else vectors[0] if len(vectors)>1:
non_vector_sum = tensor.add(*non_vectors) if len(non_vectors)>1 else non_vectors[0] vector_sum = tensor.add(*vectors)
else:
vector_sum = vectors[0]
#backport
#vector_sum = tensor.add(*vectors) if len(vectors)>1 else vectors[0]
if len(non_vectors)>1:
non_vector_sum = tensor.add(*non_vectors)
else:
non_vector_sum = non_vectors[0]
#non_vector_sum = tensor.add(*non_vectors) if len(non_vectors)>1 else non_vectors[0]
try: try:
sm_bias = softmax_with_bias(non_vector_sum, vector_sum) sm_bias = softmax_with_bias(non_vector_sum, vector_sum)
except: except:
...@@ -909,8 +920,14 @@ def categorical_crossentropy(coding_dist, true_dist, axis=1): ...@@ -909,8 +920,14 @@ def categorical_crossentropy(coding_dist, true_dist, axis=1):
if true_dist.ndim == 2: if true_dist.ndim == 2:
return -theano.sum(true_dist * log(coding_dist), axis=axis) return -theano.sum(true_dist * log(coding_dist), axis=axis)
else: else:
if axis == 0:
retval = coding_dist.T
else:
retval = coding_dist,
return categorical_crossentropy_1hot( return categorical_crossentropy_1hot(
coding_dist.T if axis == 0 else coding_dist, #backport
#coding_dist.T if axis == 0 else coding_dist,
retval,
true_dist) true_dist)
......
...@@ -17,7 +17,7 @@ import sys ...@@ -17,7 +17,7 @@ import sys
from theano import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from theano.compile.debugmode import _debugprint from theano.compile.debugmode import _debugprint
from theano.gof.python25 import any from theano.gof.python25 import any, all
# Utilities # Utilities
......
...@@ -18,10 +18,18 @@ def fetch_seed(pseed=None): ...@@ -18,10 +18,18 @@ def fetch_seed(pseed=None):
""" """
seed = pseed or os.getenv("THEANO_UNITTEST_SEED", 666) seed = pseed or os.getenv("THEANO_UNITTEST_SEED", 666)
seed = None if seed=='random' else seed if seed=='random':
seed = None
#backport
#seed = None if seed=='random' else seed
try: try:
seed = int(seed) if seed else None if seed:
seed = int(seed)
else:
seed = None
#backport
#seed = int(seed) if seed else None
except ValueError: except ValueError:
print >> sys.stderr, 'Error: THEANO_UNITTEST_SEED contains '\ print >> sys.stderr, 'Error: THEANO_UNITTEST_SEED contains '\
'invalid seed, using None instead' 'invalid seed, using None instead'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论