提交 188b2809 authored 作者: James Bergstra's avatar James Bergstra

I think these are fixes to Composite.eq and hash.

But they cause a mysterious bug in test_gpu_fusion on my mac in which the second of cases 2 and 3 (any order) fails to update the result buffer. I tried but failed to fix it so far.
上级 6c7a8fc2
...@@ -1957,18 +1957,6 @@ class Composite(ScalarOp): ...@@ -1957,18 +1957,6 @@ class Composite(ScalarOp):
if not isinstance(node.op, ScalarOp): if not isinstance(node.op, ScalarOp):
raise ValueError("The env to Composite must be exclusively composed of ScalarOp instances.") raise ValueError("The env to Composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in xrange(len(inputs))]) +
zip(outputs,
["%%(o%i)s"%i for i in xrange(len(outputs))]))
for orphan in env.variables: #env.orphans:
if orphan.owner is None and orphan not in env.inputs:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError("All orphans in the env to Composite must be Constant instances.")
if not hasattr(self,"name"): if not hasattr(self,"name"):
l=[] l=[]
for n in env.toposort(): for n in env.toposort():
...@@ -1981,29 +1969,6 @@ class Composite(ScalarOp): ...@@ -1981,29 +1969,6 @@ class Composite(ScalarOp):
l.append(v) l.append(v)
self.name="Composite{"+",".join(l)+"}" self.name="Composite{"+",".join(l)+"}"
_c_code = "{\n"
i = 0
j = 0
for node in env.toposort():
j += 1
for output in node.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
_c_code += "%s %s;\n" % (output.type.dtype_specs()[1], name)
s = node.op.c_code(node,
"%(name)s",
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += s
_c_code += "\n"
_c_code += "}\n"
def compose_impl(r): def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1) # this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice # it will calculate *1 twice
...@@ -2020,7 +1985,6 @@ class Composite(ScalarOp): ...@@ -2020,7 +1985,6 @@ class Composite(ScalarOp):
_impls = [compose_impl(r) for r in env.outputs] _impls = [compose_impl(r) for r in env.outputs]
self._c_code = _c_code
self._impls = _impls self._impls = _impls
self.nin = len(inputs) self.nin = len(inputs)
self.nout = len(outputs) self.nout = len(outputs)
...@@ -2059,10 +2023,56 @@ class Composite(ScalarOp): ...@@ -2059,10 +2023,56 @@ class Composite(ScalarOp):
#It won't generate conflicting variable name. #It won't generate conflicting variable name.
d['id']='_DUMMY_ID_' d['id']='_DUMMY_ID_'
return self._c_code % d subd = dict(
zip(self.env.inputs,
["%%(i%i)s"%i for i in xrange(len(self.env.inputs))])
+ zip(self.env.outputs,
["%%(o%i)s"%i for i in xrange(len(self.env.outputs))]))
for orphan in self.env.variables: #env.orphans:
if orphan.owner is None and orphan not in self.env.inputs:
if isinstance(orphan, Constant):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError(
"All orphans in the env to Composite must"
" be Constant instances.")
_c_code = "{\n"
i = 0
j = 0
for node in self.env.toposort():
j += 1
for output in node.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
_c_code += "%s %s;\n" % (
output.type.dtype_specs()[1], name)
s = node.op.c_code(node,
"%(name)s",
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += s
_c_code += "\n"
_c_code += "}\n"
return _c_code % d
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,)+tuple([x.op.c_code_cache_version() for x in self.env.toposort()]) rval = [1]
for x in self.env.toposort():
xv = x.op.c_code_cache_version()
if xv:
rval.append(xv)
else:
return ()
return tuple(rval)
def c_support_code(self): def c_support_code(self):
str = "" str = ""
...@@ -2099,13 +2109,12 @@ class Composite(ScalarOp): ...@@ -2099,13 +2109,12 @@ class Composite(ScalarOp):
return self._hashval return self._hashval
def __getstate__(self): def __getstate__(self):
d = copy(self.__dict__) return dict(
d.pop('env') inputs=self.inputs,
d.pop('_impls') outputs=self.outputs)
return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) ## self.__dict__.update(d)
# We must call init to set env and _impls again, as otherwise # We must call init to set env and _impls again, as otherwise
# self.perform will not work. # self.perform will not work.
self.__init__(self.inputs, self.outputs) self.__init__(d['inputs'], d['outputs'])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论