提交 5ee86171 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1281 from abalkin/py3k-fixes

Fixes for python 3 compatibility.
......@@ -30,7 +30,12 @@ if PY3:
from itertools import combinations, product
from sys import maxsize
def decode(x):
return x.decode()
def decode_iter(itr):
for x in itr:
yield x.decode()
else:
from operator import div as operator_div
......@@ -44,3 +49,9 @@ else:
from theano.compat.python2x import all, any, partial, defaultdict, deque
from theano.compat.python2x import combinations, product, maxsize
from theano.compat.python2x import DictMixin, OrderedDict
def decode(x):
return x
def decode_iter(x):
return x
......@@ -13,13 +13,14 @@ import subprocess
import sys
import tempfile
import time
import itertools
import distutils.sysconfig
import numpy.distutils # TODO: TensorType should handle this
import theano
from theano.compat import PY3, b, next
from theano.compat import PY3, next, decode, decode_iter
from theano.gof.utils import flatten
from theano.configparser import config
from theano.gof.cc import hash_from_code
......@@ -1470,6 +1471,34 @@ def gcc_version():
return gcc_version_str
def gcc_llvm():
""" Detect if the g++ version used is the llvm one or not.
It don't support all g++ parameters even if it support many of them.
"""
if gcc_llvm.is_llvm is None:
pass
p = None
try:
p = call_subprocess_Popen(['g++', '--version'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
p.wait()
output = p.stdout.read() + p.stderr.read()
except OSError:
# Typically means g++ cannot be found.
# So it is not an llvm compiler.
# Normally this should not happen as we should not try to
# compile when g++ is not available. If this happen, it
# will crash later so supposing it is not llvm is "safe".
output = b('')
del p
gcc_llvm.is_llvm = b("llvm") in output
return gcc_llvm.is_llvm
gcc_llvm.is_llvm = None
class GCC_compiler(object):
# The equivalent flags of --march=native used by g++.
march_flags = None
......@@ -1515,11 +1544,11 @@ class GCC_compiler(object):
if p.returncode != 0:
return None
stdout = p.stdout.readlines()
stderr = p.stderr.readlines()
stdout = decode_iter(p.stdout.readlines())
stderr = decode_iter(p.stderr.readlines())
lines = []
if parse:
for line in stdout + stderr:
for line in itertools.chain(stdout, stderr):
if "COLLECT_GCC_OPTIONS=" in line:
continue
elif "-march=" in line and "-march=native" not in line:
......@@ -1528,8 +1557,8 @@ class GCC_compiler(object):
lines.append(line.strip())
lines = list(set(lines)) # to remove duplicate
else:
lines = stdout + stderr
return lines
lines = itertools.chain(stdout, stderr)
return list(lines)
# The '-' at the end is needed. Otherwise, g++ do not output
# enough information.
......@@ -1804,7 +1833,7 @@ class GCC_compiler(object):
try:
p = call_subprocess_Popen(cmd, stderr=subprocess.PIPE)
compile_stderr = p.communicate()[1]
compile_stderr = decode(p.communicate()[1])
except Exception:
# An exception can occur e.g. if `g++` is not found.
print_command_line_error()
......@@ -1825,7 +1854,7 @@ class GCC_compiler(object):
# prints the exception, having '\n' in the text makes it more
# difficult to read.
raise Exception('Compilation failed (return status=%s): %s' %
(status, compile_stderr.replace(b('\n'), b('. '))))
(status, compile_stderr.replace('\n', '. ')))
elif config.cmodule.compilation_warning and compile_stderr:
# Print errors just below the command line.
print compile_stderr
......
......@@ -785,8 +785,7 @@ def scan(fn,
not isinstance(x, SharedVariable) and
not isinstance(x, gof.Constant)),
gof.graph.inputs(fake_outputs))
extra_inputs = filter(lambda x: x not in args + fake_nonseqs,
all_inputs)
extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
non_seqs += extra_inputs
## Note we do not use all_inputs directly since the order of variables
## in args is quite important
......
......@@ -17,6 +17,7 @@ from theano.gof.python25 import any
from theano.tests import unittest_tools as utt
import theano.scalar.sharedvar
from theano.gof.python25 import OrderedDict
from theano.compat import PY3
from numpy.testing.noseclasses import KnownFailureTest
......@@ -2366,8 +2367,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 2)
sx, upx = theano.scan(sum, sequences=[x], n_steps=2)
......@@ -2376,8 +2377,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 2)
sx, upx = theano.scan(sum, sequences=[x], n_steps=4)
......@@ -2386,8 +2387,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x, y], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n: isinstance(
n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 1)
sx, upx = theano.scan(sum, sequences=[x])
......@@ -2396,8 +2397,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 1)
sx, upx = theano.scan(sum, sequences=[x])
......@@ -2406,8 +2407,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 1)
sx, upx = theano.scan(sum, sequences=[x])
......@@ -2416,8 +2417,8 @@ class T_Scan(unittest.TestCase):
f = theano.function([x], [sx, sy],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = filter(lambda n:
isinstance(n.op, theano.scan_module.scan_op.Scan), topo)
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 2)
def test_hash(self):
......@@ -3515,12 +3516,20 @@ def test_speed():
t0 = time.time()
r_i = iter(r[1:])
r_ii = iter(r[:-1])
while True:
try:
tmp = r_i.next()
tmp += r_ii.next()
except StopIteration:
break
if PY3:
while True:
try:
tmp = next(r_i)
tmp += next(r_ii)
except StopIteration:
break
else:
while True:
try:
tmp = r_i.next()
tmp += r_ii.next()
except StopIteration:
break
t1 = time.time()
print 'python with builtin iterator', t1 - t0
......
......@@ -4609,12 +4609,13 @@ class Subtensor(Op):
# There is a bug in numpy that results in isinstance(x, int) returning
# False for numpy integers.
# See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(entry, (numpy.integer, int)):
elif isinstance(entry, numpy.integer):
return entry
# On Windows 64-bit, shapes are returned as Python long, as they can
# be bigger than what a Python int can hold.
# Shapes should always fit in a numpy.int64, and we support them better
elif isinstance(entry, long):
# 2) In Python3, long replaced int. So we must assert it fit in int64.
elif isinstance(entry, (int, long)):
entry64 = numpy.int64(entry)
return entry64
else:
......
......@@ -1625,7 +1625,7 @@ def local_useless_subtensor(node):
except NotScalarConstantError:
pass
if isinstance(idx.stop, int):
if isinstance(idx.stop, (int, numpy.integer)):
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
......
......@@ -6438,7 +6438,7 @@ class T_long_tensor(unittest.TestCase):
for exp in xrange(64):
val = 2L ** exp - 1
scalar_ct = constant(val)
assert scalar_ct.dtype == 'int64'
assert scalar_ct.dtype.startswith('int')
assert scalar_ct.value == val
vector_ct = constant([val, val])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论