提交 5429c30a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3694 from nouiz/crash_cache_c_scalar

Crash cache c scalar
......@@ -510,7 +510,8 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
"""
name = "V%i" % id
symbol_table[variable] = name
if variable not in symbol_table:
symbol_table[variable] = name
sub = dict(sub)
# sub['name'] = name
sub['id'] = id
......@@ -608,9 +609,20 @@ class CLinker(link.Linker):
self.orphans = list(r for r in self.variables
if isinstance(r, graph.Constant) and
r not in self.inputs)
# C type constants (theano.scalar.Scalar). They don't request an object
self.consts = []
# Move c type from orphans (theano.scalar.Scalar) to self.consts
for variable in self.orphans:
if isinstance(variable, graph.Constant):
try:
variable.type.c_literal(variable.data)
self.consts.append(variable)
self.orphans.remove(variable)
except (utils.MethodNotDefined, NotImplementedError):
pass
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
def code_gen(self):
"""
......@@ -634,8 +646,6 @@ class CLinker(link.Linker):
no_recycling = self.no_recycling
self.consts = []
c_support_code_apply = []
c_init_code_apply = []
......@@ -664,7 +674,11 @@ class CLinker(link.Linker):
# [what to declare in each run,
# what to do at the beginning of each run,
# what to do at the end of each run]]
if variable in self.inputs:
if variable in self.consts:
symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")")
continue
elif variable in self.inputs:
# We need to extract the new inputs at each run
# they do not need to be relayed to Python, so we don't sync.
# If the variable is both an input and an output, there is
......@@ -675,15 +689,6 @@ class CLinker(link.Linker):
if not isinstance(variable, graph.Constant):
raise TypeError("All orphans to CLinker must be Constant"
" instances.", variable)
if isinstance(variable, graph.Constant):
try:
symbol[variable] = ("(" + variable.type.c_literal(
variable.data) + ")")
self.consts.append(variable)
self.orphans.remove(variable)
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
# orphans are not inputs so we'll just get fetch them
# when we initialize the struct and assume they stay
# the same
......@@ -1159,13 +1164,6 @@ class CLinker(link.Linker):
for v in self.variables:
if v in self.consts:
continue
if v in self.orphans and isinstance(v, graph.Constant):
try:
# constant will be inlined, no need to get
v.type.c_literal(v.data)
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id + 1))
id += 2
......
......@@ -19,7 +19,7 @@ def as_variable(x):
class TDouble(Type):
def filter(self, data):
def filter(self, data, strict=False, allow_downcast=False):
return float(data)
def c_declare(self, name, sub, check_input=True):
......@@ -103,7 +103,7 @@ class MyOp(Op):
out[0] = self.impl(*inputs)
def c_code_cache_version(self):
return ()
return (1,)
# class Unary(MyOp):
......@@ -201,6 +201,35 @@ def test_clinker_literal_inlining():
assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_literal_cache():
# This caused bugs in the past related to the cache.
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
mode = theano.Mode(linker='c')
A = theano.tensor.matrix()
input1 = theano.tensor.vector()
normal_svd = numpy.array([[5.936276e+01, -4.664007e-07, -2.56265e-06],
[-4.664007e-07, 9.468691e-01, -3.18862e-02],
[-2.562651e-06, -3.188625e-02, 1.05226e+00]],
dtype=theano.config.floatX)
orientationi = numpy.array([59.36276866, 1.06116353, 0.93797339],
dtype=theano.config.floatX)
for out1 in [A - input1[0] * numpy.identity(3),
input1[0] * numpy.identity(3)]:
benchmark = theano.function(
inputs=[A, input1],
outputs=[out1],
on_unused_input='ignore',
mode=mode)
out1 = benchmark(normal_svd, orientationi)
def test_clinker_single_node():
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论