提交 999cbee0 authored 作者: abergeron's avatar abergeron

Merge pull request #2115 from nouiz/py3

Fix optimization not always registered and python 3 crash
...@@ -4,7 +4,8 @@ from theano import config ...@@ -4,7 +4,8 @@ from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer from theano.tensor.blas import (
blas_optdb, optdb, local_optimizer, EquilibriumOptimizer)
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -28,9 +29,9 @@ class BaseBLAS(object): ...@@ -28,9 +29,9 @@ class BaseBLAS(object):
return blas_header_text() return blas_header_text()
####### ####### ####### # ##### ####### #######
# GER # GER
####### ####### ####### # ##### ####### #######
def ger_c_code(A, a, x, y, Z, destructive, fail): def ger_c_code(A, a, x, y, Z, destructive, fail):
return """ return """
...@@ -250,8 +251,8 @@ class CGer(BaseBLAS, Ger): ...@@ -250,8 +251,8 @@ class CGer(BaseBLAS, Ger):
A, a, x, y = inp A, a, x, y = inp
Z, = out Z, = out
code = ger_c_code(A, a, x, y, Z, code = ger_c_code(A, a, x, y, Z,
destructive=int(self.destructive), destructive=int(self.destructive),
fail=sub['fail']) fail=sub['fail'])
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node): ...@@ -279,12 +280,13 @@ def make_c_ger_destructive(node):
return [cger_inplace(*node.inputs)] return [cger_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# GEMV # GEMV
####### ####### ####### # ##### ####### #######
def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=False): def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail,
force_init_beta=False):
""" """
zz <- beta * aa + alpha * dot(xx, yy) zz <- beta * aa + alpha * dot(xx, yy)
...@@ -611,18 +613,17 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta= ...@@ -611,18 +613,17 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail, force_init_beta=
class CGemv(BaseBLAS, Gemv): class CGemv(BaseBLAS, Gemv):
def __init__(self, inplace, force_init_beta=False): def __init__(self, inplace, force_init_beta=False):
super(CGemv, self).__init__(inplace) super(CGemv, self).__init__(inplace)
self.force_init_beta = force_init_beta
self.force_init_beta = force_init_beta
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
aa, alpha, xx, yy, beta = inp aa, alpha, xx, yy, beta = inp
zz, = out zz, = out
code = gemv_c_code( code = gemv_c_code(
aa, xx, yy, zz, alpha, beta, aa, xx, yy, zz, alpha, beta,
destructive=int(self.inplace), destructive=int(self.inplace),
fail=sub['fail'], fail=sub['fail'],
force_init_beta=self.force_init_beta force_init_beta=self.force_init_beta
) )
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -630,6 +631,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -630,6 +631,7 @@ class CGemv(BaseBLAS, Gemv):
cgemv_inplace = CGemv(inplace=True) cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False) cgemv_no_inplace = CGemv(inplace=False)
def check_force_gemv_init(): def check_force_gemv_init():
if check_force_gemv_init._force_init_beta is None: if check_force_gemv_init._force_init_beta is None:
""" """
...@@ -680,6 +682,7 @@ def check_force_gemv_init(): ...@@ -680,6 +682,7 @@ def check_force_gemv_init():
check_force_gemv_init._force_init_beta = None check_force_gemv_init._force_init_beta = None
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node): def use_c_gemv(node):
if not config.blas.ldflags: if not config.blas.ldflags:
...@@ -709,7 +712,8 @@ def use_c_gemv(node): ...@@ -709,7 +712,8 @@ def use_c_gemv(node):
""" """
force_init_beta = check_force_gemv_init() force_init_beta = check_force_gemv_init()
return [CGemv(inplace=False, force_init_beta=force_init_beta)(*node.inputs)] return [CGemv(inplace=False,
force_init_beta=force_init_beta)(*node.inputs)]
if (node.op == gemv_inplace and if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
...@@ -721,15 +725,13 @@ def make_c_gemv_destructive(node): ...@@ -721,15 +725,13 @@ def make_c_gemv_destructive(node):
return [cgemv_inplace(*node.inputs)] return [cgemv_inplace(*node.inputs)]
####### ####### ####### # ##### ####### #######
# Optimizers # Optimizers
####### ####### ####### # ##### ####### #######
blas_optdb.register('use_c_blas', blas_optdb.register('use_c_blas',
in2out(use_c_ger, use_c_gemv), in2out(use_c_ger, use_c_gemv),
20, 'fast_run', 'c_blas') 20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py # this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive', optdb.register('c_blas_destructive',
......
...@@ -312,7 +312,7 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75, ...@@ -312,7 +312,7 @@ compile.optdb.register('inplace_elemwise_opt', inplace_elemwise_optimizer, 75,
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
if type(lopt) == str: if type(lopt) == str:
def register(inner_lopt): def register(inner_lopt):
return register_canonicalize(inner_lopt, *tags, **kwargs) return register_canonicalize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
...@@ -323,7 +323,7 @@ def register_canonicalize(lopt, *tags, **kwargs): ...@@ -323,7 +323,7 @@ def register_canonicalize(lopt, *tags, **kwargs):
def register_stabilize(lopt, *tags, **kwargs): def register_stabilize(lopt, *tags, **kwargs):
if type(lopt) == str: if type(lopt) == str:
def register(inner_lopt): def register(inner_lopt):
return register_stabilize(inner_lopt, *tags, **kwargs) return register_stabilize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
...@@ -334,7 +334,7 @@ def register_stabilize(lopt, *tags, **kwargs): ...@@ -334,7 +334,7 @@ def register_stabilize(lopt, *tags, **kwargs):
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
if type(lopt) == str: if type(lopt) == str:
def register(inner_lopt): def register(inner_lopt):
return register_specialize(inner_lopt, *tags, **kwargs) return register_specialize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
......
...@@ -3198,6 +3198,18 @@ class T_useless_elemwise(unittest.TestCase): ...@@ -3198,6 +3198,18 @@ class T_useless_elemwise(unittest.TestCase):
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
def test_constant_folding():
    """Check that constant folding is registered for FAST_COMPILE.

    A past bug dropped that registration while the optimizations were
    being registered; this test guards against it coming back.
    """
    vec = tensor.dvector()
    # Start from the FAST_COMPILE mode and exclude the "fusion" tag.
    fast_compile = theano.compile.get_mode("FAST_COMPILE")
    mode = fast_compile.excluding("fusion")
    fn = theano.function([vec], [vec * 2, vec + vec], mode=mode)
    nodes = fn.maker.fgraph.toposort()
    # The compiled graph is expected to contain exactly two nodes.
    assert len(nodes) == 2
def test_constant_get_stabilized(): def test_constant_get_stabilized():
""" """
Currently Theano enable the constant_folding optimization before stabilization optimization. Currently Theano enable the constant_folding optimization before stabilization optimization.
......
...@@ -58,6 +58,11 @@ class SliceType(Type): ...@@ -58,6 +58,11 @@ class SliceType(Type):
def __hash__(self): def __hash__(self):
return hashtype(self) return hashtype(self)
@staticmethod
def may_share_memory(a, b):
    """Return True only when ``a`` is a slice and ``b`` is that same object.

    Two distinct slice objects are never reported as sharing memory, and
    neither is any ``a`` that is not a ``slice`` instance.
    """
    if not isinstance(a, slice):
        return False
    return a is b
slicetype = SliceType() slicetype = SliceType()
...@@ -72,6 +77,12 @@ class NoneTypeT(Generic): ...@@ -72,6 +77,12 @@ class NoneTypeT(Generic):
else: else:
raise TypeError('Expected None!') raise TypeError('Expected None!')
@staticmethod
def may_share_memory(a, b):
    """Always return False.

    Python's ``None`` is a singleton, but in the sense used by DebugMode
    None values are never considered to share memory between objects.
    """
    return False
none_type_t = NoneTypeT() none_type_t = NoneTypeT()
# This is a variable instance. It can be used only once per fgraph. # This is a variable instance. It can be used only once per fgraph.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论