提交 85f9247a authored 作者: lamblin's avatar lamblin

Merge pull request #937 from nouiz/small_fix

Small fix
...@@ -1607,7 +1607,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1607,7 +1607,7 @@ class _Linker(gof.link.LocalLinker):
if not isinstance(node.op, gof.op.Op): if not isinstance(node.op, gof.op.Op):
raise utils.MethodNotDefined() raise utils.MethodNotDefined()
e = FunctionGraph(*graph.clone(node.inputs, node.outputs)) e = FunctionGraph(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes # WARNING: STOCHASTIC ORDER e.toposort = lambda: e.apply_nodes # WARNING: STOCHASTIC ORDER
# Specifically... e.nodes is a set, but of only 1 element # Specifically... e.nodes is a set, but of only 1 element
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs,
......
import theano import theano
from theano.gof.utils import give_variables_names, unique from theano.gof.utils import give_variables_names, unique
from theano.gof.python25 import all
def test_give_variables_names(): def test_give_variables_names():
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
...@@ -10,6 +12,7 @@ def test_give_variables_names(): ...@@ -10,6 +12,7 @@ def test_give_variables_names():
assert all(var.name for var in variables) assert all(var.name for var in variables)
assert unique([var.name for var in variables]) assert unique([var.name for var in variables])
def test_give_variables_names_idempotence(): def test_give_variables_names_idempotence():
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
y = x + 1 y = x + 1
...@@ -23,6 +26,7 @@ def test_give_variables_names_idempotence(): ...@@ -23,6 +26,7 @@ def test_give_variables_names_idempotence():
assert names == names2 assert names == names2
def test_give_variables_names_small(): def test_give_variables_names_small():
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x) y = theano.tensor.dot(x, x)
...@@ -30,4 +34,3 @@ def test_give_variables_names_small(): ...@@ -30,4 +34,3 @@ def test_give_variables_names_small():
give_variables_names(fgraph.variables) give_variables_names(fgraph.variables)
assert all(var.name for var in fgraph.variables) assert all(var.name for var in fgraph.variables)
assert unique([var.name for var in fgraph.variables]) assert unique([var.name for var in fgraph.variables])
...@@ -854,7 +854,8 @@ class CSMGrad(gof.op.Op): ...@@ -854,7 +854,8 @@ class CSMGrad(gof.op.Op):
sp_dim = x_shape[0] sp_dim = x_shape[0]
g_row = numpy.zeros(sp_dim, dtype=g_data.dtype) g_row = numpy.zeros(sp_dim, dtype=g_data.dtype)
gout_data = numpy.zeros_like(x_data) gout_data = numpy.zeros_like(x_data, dtype=node.outputs[0].dtype)
for i in range(len(x_indptr) - 1): for i in range(len(x_indptr) - 1):
for j_ptr in range(g_indptr[i], g_indptr[i + 1]): for j_ptr in range(g_indptr[i], g_indptr[i + 1]):
g_row[g_indices[j_ptr]] += g_data[j_ptr] g_row[g_indices[j_ptr]] += g_data[j_ptr]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论