提交 5429c30a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3694 from nouiz/crash_cache_c_scalar

Crash cache c scalar
...@@ -510,6 +510,7 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub): ...@@ -510,6 +510,7 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
""" """
name = "V%i" % id name = "V%i" % id
if variable not in symbol_table:
symbol_table[variable] = name symbol_table[variable] = name
sub = dict(sub) sub = dict(sub)
# sub['name'] = name # sub['name'] = name
...@@ -608,9 +609,20 @@ class CLinker(link.Linker): ...@@ -608,9 +609,20 @@ class CLinker(link.Linker):
self.orphans = list(r for r in self.variables self.orphans = list(r for r in self.variables
if isinstance(r, graph.Constant) and if isinstance(r, graph.Constant) and
r not in self.inputs) r not in self.inputs)
# C type constants (theano.scalar.Scalar). They don't request an object
self.consts = []
# Move c type from orphans (theano.scalar.Scalar) to self.consts
for variable in self.orphans:
if isinstance(variable, graph.Constant):
try:
variable.type.c_literal(variable.data)
self.consts.append(variable)
self.orphans.remove(variable)
except (utils.MethodNotDefined, NotImplementedError):
pass
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
def code_gen(self): def code_gen(self):
""" """
...@@ -634,8 +646,6 @@ class CLinker(link.Linker): ...@@ -634,8 +646,6 @@ class CLinker(link.Linker):
no_recycling = self.no_recycling no_recycling = self.no_recycling
self.consts = []
c_support_code_apply = [] c_support_code_apply = []
c_init_code_apply = [] c_init_code_apply = []
...@@ -664,7 +674,11 @@ class CLinker(link.Linker): ...@@ -664,7 +674,11 @@ class CLinker(link.Linker):
# [what to declare in each run, # [what to declare in each run,
# what to do at the beginning of each run, # what to do at the beginning of each run,
# what to do at the end of each run]] # what to do at the end of each run]]
if variable in self.inputs: if variable in self.consts:
symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")")
continue
elif variable in self.inputs:
# We need to extract the new inputs at each run # We need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync. # they do not need to be relayed to Python, so we don't sync.
# If the variable is both an input and an output, there is # If the variable is both an input and an output, there is
...@@ -675,15 +689,6 @@ class CLinker(link.Linker): ...@@ -675,15 +689,6 @@ class CLinker(link.Linker):
if not isinstance(variable, graph.Constant): if not isinstance(variable, graph.Constant):
raise TypeError("All orphans to CLinker must be Constant" raise TypeError("All orphans to CLinker must be Constant"
" instances.", variable) " instances.", variable)
if isinstance(variable, graph.Constant):
try:
symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")")
self.consts.append(variable)
self.orphans.remove(variable)
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
# orphans are not inputs so we'll just get fetch them # orphans are not inputs so we'll just get fetch them
# when we initialize the struct and assume they stay # when we initialize the struct and assume they stay
# the same # the same
...@@ -1159,13 +1164,6 @@ class CLinker(link.Linker): ...@@ -1159,13 +1164,6 @@ class CLinker(link.Linker):
for v in self.variables: for v in self.variables:
if v in self.consts: if v in self.consts:
continue continue
if v in self.orphans and isinstance(v, graph.Constant):
try:
# constant will be inlined, no need to get
v.type.c_literal(v.data)
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
init_tasks.append((v, 'init', id)) init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id + 1)) tasks.append((v, 'get', id + 1))
id += 2 id += 2
......
...@@ -19,7 +19,7 @@ def as_variable(x): ...@@ -19,7 +19,7 @@ def as_variable(x):
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data, strict=False, allow_downcast=False):
return float(data) return float(data)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
...@@ -103,7 +103,7 @@ class MyOp(Op): ...@@ -103,7 +103,7 @@ class MyOp(Op):
out[0] = self.impl(*inputs) out[0] = self.impl(*inputs)
def c_code_cache_version(self): def c_code_cache_version(self):
return () return (1,)
# class Unary(MyOp): # class Unary(MyOp):
...@@ -201,6 +201,35 @@ def test_clinker_literal_inlining(): ...@@ -201,6 +201,35 @@ def test_clinker_literal_inlining():
assert "4.12345678" in code # we expect the number to be inlined assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_literal_cache():
# This caused bugs in the past related to the cache.
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
mode = theano.Mode(linker='c')
A = theano.tensor.matrix()
input1 = theano.tensor.vector()
normal_svd = numpy.array([[5.936276e+01, -4.664007e-07, -2.56265e-06],
[-4.664007e-07, 9.468691e-01, -3.18862e-02],
[-2.562651e-06, -3.188625e-02, 1.05226e+00]],
dtype=theano.config.floatX)
orientationi = numpy.array([59.36276866, 1.06116353, 0.93797339],
dtype=theano.config.floatX)
for out1 in [A - input1[0] * numpy.identity(3),
input1[0] * numpy.identity(3)]:
benchmark = theano.function(
inputs=[A, input1],
outputs=[out1],
on_unused_input='ignore',
mode=mode)
out1 = benchmark(normal_svd, orientationi)
def test_clinker_single_node(): def test_clinker_single_node():
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论