提交 87c4e01f authored 作者: carriepl's avatar carriepl

Merge pull request #3811 from nouiz/typed_list

Add TypedListConstant, this fix an optimization warning.
......@@ -568,7 +568,7 @@ class Constant(Variable):
else:
name = str(self.data)
if len(name) > 20:
name = name[:10] + '...' + name[-10]
name = name[:10] + '...' + name[-10:]
return 'Constant{%s}' % name
def clone(self):
......
......@@ -1169,7 +1169,9 @@ class ShapeFeature(object):
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = []
for i, ps in enumerate(other_shape):
if (ps.owner and
if r_shape is None and other_shape:
merged_shape.append(other_shape[i])
elif (ps.owner and
isinstance(getattr(ps.owner, 'op', None), Shape_i) and
ps.owner.op.i == i and
ps.owner.inputs[0] in (r, other_r)):
......
......@@ -53,6 +53,15 @@ class TypedListVariable(_typed_list_py_operators, Variable):
TypedListType.Variable = TypedListVariable
class TypedListConstant(_typed_list_py_operators, Constant):
"""
Subclass to add the typed list operators to the basic `Variable` class.
"""
TypedListType.Constant = TypedListConstant
class GetItem(Op):
# See doc in instance of this Op or function after this class definition.
view_map = {0: [0]}
......
......@@ -110,3 +110,13 @@ class test_inplace(unittest.TestCase):
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x]))
def test_constant_folding():
m = theano.tensor.ones((1,), dtype='int8')
l = theano.typed_list.make_list([m, m])
f = theano.function([], l)
topo = f.maker.fgraph.toposort()
assert len(topo)
assert isinstance(topo[0].op, theano.compile.ops.DeepCopyOp)
assert f() == [1, 1]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论