提交 85f83402 authored 作者: lamblin's avatar lamblin

Merge pull request #1409 from nouiz/mixed

Fix opt error/skip in scan, add Shape.c_code()
...@@ -1703,6 +1703,7 @@ class GCC_compiler(object): ...@@ -1703,6 +1703,7 @@ class GCC_compiler(object):
flags = list(flags) flags = list(flags)
compilation_ok = True compilation_ok = True
run_ok = False
try: try:
fd, path = tempfile.mkstemp(suffix='.c', prefix=tmp_prefix) fd, path = tempfile.mkstemp(suffix='.c', prefix=tmp_prefix)
exe_path = path[:-2] exe_path = path[:-2]
...@@ -1719,7 +1720,6 @@ class GCC_compiler(object): ...@@ -1719,7 +1720,6 @@ class GCC_compiler(object):
compilation_ok = False compilation_ok = False
elif try_run: elif try_run:
# Try to execute the program # Try to execute the program
run_ok = False
try: try:
proc = call_subprocess_Popen([exe_path], proc = call_subprocess_Popen([exe_path],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
......
...@@ -123,10 +123,15 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -123,10 +123,15 @@ def remove_constants_and_unused_inputs_scan(node):
if isinstance(nw_out, tensor.Constant): if isinstance(nw_out, tensor.Constant):
givens[nw_in] = nw_out.clone() givens[nw_in] = nw_out.clone()
elif nw_in in all_ins: elif nw_in in all_ins:
identical_non_seqs = [x for x in outer_non_seqs[:idx] identical_non_seqs = [x for x in nw_outer
if scan_utils.equal_computations( if scan_utils.equal_computations(
[x], [nw_out])] [x], [nw_out])]
if identical_non_seqs: if identical_non_seqs:
identical_idx = outer_non_seqs.index(identical_non_seqs[0])
# If we have identical non sequences, the previous one
# must be in nw_inner or be a constant.
assert (non_seqs[identical_idx] in nw_inner or
isinstance(identical_non_seqs[0], tensor.Constant))
index = outer_non_seqs.index(identical_non_seqs[0]) index = outer_non_seqs.index(identical_non_seqs[0])
givens[nw_in] = non_seqs[index] givens[nw_in] = non_seqs[index]
else: else:
......
...@@ -3486,6 +3486,51 @@ class T_Scan(unittest.TestCase): ...@@ -3486,6 +3486,51 @@ class T_Scan(unittest.TestCase):
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2]) assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1]) assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
def test_remove_constants_and_unused_inputs_scan(self):
    """
    Test the opt remove_constants_and_unused_inputs_scan.

    Each scan below receives ``W`` as a non-sequence, optionally
    together with one or two copies of ``W[0]`` placed at different
    positions.  The inner function only ever uses ``W``, so after
    compilation the scan node is expected to keep exactly one
    non-sequence, both in the inner graph and in the outer inputs.

    TODO: currently we only test non_seqs; the other kinds of scan
    inputs should be tested as well.
    """
    W = theano.tensor.matrix(name='W')
    v = theano.tensor.ivector(name='v')
    y1, _ = theano.scan(lambda i, W: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W])
    y2, _ = theano.scan(lambda i, _, W: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W[0], W])
    y3, _ = theano.scan(lambda i, W, _: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W, W[0]])
    y4, _ = theano.scan(lambda i, _, _2, W: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W[0], W[0], W])
    y5, _ = theano.scan(lambda i, _, W, _2: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W[0], W, W[0]])
    y6, _ = theano.scan(lambda i, W, _, _2: W[i], sequences=v,
                        outputs_info=None, non_sequences=[W, W[0], W[0]])
    # TODO: y7 has a problem at run time. It should probably raise an
    # error during the scan construction instead.
    #y7, _ = theano.scan(lambda i, W, _, _2: W[i], sequences=v,
    #                    outputs_info=None, non_sequences=[v, W[0], W])

    # W is a floatX matrix: build the test value with the same dtype so
    # the call is not rejected when floatX is float32.
    W_val = numpy.zeros((3, 3), dtype=theano.config.floatX)

    for out in [y1, y2, y3, y4, y5, y6]:
        # Compiling and running this used to raise an exception.
        f = theano.function([W, v], out)
        res = f(W_val, [1, 2])
        # Rows 1 and 2 of an all-zero 3x3 matrix.
        assert res.shape == (2, 3)
        assert numpy.allclose(res, 0)

        scan_node = f.maker.fgraph.toposort()[-1]
        # TODO: Why does this assert always fail?
        # assert (len(scan_node.inputs) ==
        #         len(set(scan_node.inputs)))
        inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
        assert len(inp) == 1
        assert (len(inp) == len(set(inp)))
        inp = scan_node.op.outer_non_seqs(scan_node)
        assert len(inp) == 1
        assert (len(inp) == len(set(inp)))
def test_speed(): def test_speed():
# #
......
...@@ -2459,6 +2459,30 @@ class Shape(Op): ...@@ -2459,6 +2459,30 @@ class Shape(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
def c_code(self, node, nodename, inp, out, sub):
    """
    Return the C source that computes the shape of the input.

    For a ``TensorType`` input the generated code (re)allocates the
    output as a 1-d ``NPY_INT64`` array with one entry per dimension of
    the input and copies each dimension into it.  For any other input
    type the request is delegated to the parent class.

    :param node: the Apply node being compiled; only
        ``node.inputs[0].type`` is inspected here.
    :param nodename: name given to this node in the generated code;
        only forwarded to the parent implementation.
    :param inp: one-element sequence with the C name of the input.
    :param out: one-element sequence with the C name of the output.
    :param sub: substitution dictionary; only forwarded to the parent
        implementation.
    :return: the C code as a string.
    """
    x, = inp
    z, = out
    if isinstance(node.inputs[0].type, TensorType):
        # The template text is kept flush-left so the emitted C source is
        # unchanged (c_code_cache_version does not need to be bumped).
        return """
npy_intp shape[] = {PyArray_NDIM(%(x)s)};
if(%(z)s == NULL || (PyArray_DIMS(%(z)s)[0] != shape[0]))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, NPY_INT64);
}
for(int i=0;i<shape[0];i++)
{
((npy_int64*)PyArray_GETPTR1(%(z)s, i))[0] = PyArray_DIMS(%(x)s)[i];
}
""" % {'x': x, 'z': z}
    else:
        # TODO: if your type is not listed here, make a registry of
        # shape ops for the various types of variables instead of adding
        # more special cases.
        # Delegate to the parent of *this* class, forwarding the real
        # argument names: ``nodename`` (not the undefined ``name``) and
        # the unpacked output name ``z``.
        return super(Shape, self).c_code(node, nodename, (x,), (z,), sub)
def c_code_cache_version(self):
    """Return the version tuple of this op's C implementation."""
    version = (1,)
    return version
@constructor @constructor
def old_shape(a): def old_shape(a):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论