提交 f56254ef authored 作者: lamblin's avatar lamblin

Merge pull request #1036 from nouiz/gpu_elem_opt

Fix an optimization crash of local_gpu_elemwise_0.
...@@ -2725,11 +2725,17 @@ class Composite(ScalarOp): ...@@ -2725,11 +2725,17 @@ class Composite(ScalarOp):
return super(Composite, self).make_node(*inputs) return super(Composite, self).make_node(*inputs)
else: else:
# Make a new op with the right input type. # Make a new op with the right input type.
assert len(inputs) == self.nin
res = theano.compile.rebuild_collect_shared( res = theano.compile.rebuild_collect_shared(
self.outputs, self.outputs,
replace=dict(zip(self.inputs, inputs)), replace=dict(zip(self.inputs, inputs)),
rebuild_strict=False) rebuild_strict=False)
node = Composite(inputs, res[1]).make_node(*inputs) # After rebuild_collect_shared, the Variables in inputs
# are not necessarily in the graph represented by res.
# res[2][0] is a dict that maps each original variable to its
# cloned variable.
cloned_inputs = [res[2][0][i] for i in inputs]
node = Composite(cloned_inputs, res[1]).make_node(*inputs)
return node return node
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
......
...@@ -101,6 +101,26 @@ class test_composite(unittest.TestCase): ...@@ -101,6 +101,26 @@ class test_composite(unittest.TestCase):
fn = gof.DualLinker().accept(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_make_node_continue_graph(self):
    # Regression test for a bug (now fixed) that disabled the
    # local_gpu_elemwise_0 optimization and printed an
    # optimization warning on the terminal.
    # It checks that Composite.make_node accepts Variables as
    # inputs, some of which represent an existing computation.
    scal = theano.scalar

    # Build a Composite computing (a * b) / c from two int8 and
    # one float32 scalar variables.
    a = scal.int8()
    b = scal.int8()
    c = scal.float32()
    composite_op = scal.Composite([a, b, c], [(a * b) / c])

    # Fresh variables, distinct from the ones the Composite was
    # built from.
    x = scal.int8()
    y = scal.int8()
    z = scal.float32()
    w = scal.float32()

    # The first input is an expression (x * w) rather than a bare
    # variable; make_node must not raise on it.
    composite_op.make_node(x * w, y, z)
class test_logical(unittest.TestCase): class test_logical(unittest.TestCase):
def test_gt(self): def test_gt(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论