提交 62144bf0 authored 作者: Frederic Bastien's avatar Frederic Bastien

make Specifiying shape work when specifying only some shape. test it.

上级 5a378b80
...@@ -1524,9 +1524,10 @@ class SpecifyShape(Op): ...@@ -1524,9 +1524,10 @@ class SpecifyShape(Op):
for dim in range(node.inputs[0].ndim): for dim in range(node.inputs[0].ndim):
try: try:
s=get_constant_value(node.inputs[1][dim]) s=get_constant_value(node.inputs[1][dim])
s=as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except TypeError, e: except TypeError, e:
new_shape.append(xshape[dim]) new_shape.append(node.inputs[1][dim])
assert len(new_shape)==len(xshape) assert len(new_shape)==len(xshape)
return [new_shape] return [new_shape]
......
...@@ -286,6 +286,56 @@ def makeSharedTester(shared_constructor_, ...@@ -286,6 +286,56 @@ def makeSharedTester(shared_constructor_,
else: else:
self.assertRaises(AssertionError, shape_constant_fct) self.assertRaises(AssertionError, shape_constant_fct)
def test_specify_shape_partial(self):
dtype = self.dtype
if dtype is None:
dtype = theano.config.floatX
rng = numpy.random.RandomState([2,4,16])
x1_1 = numpy.asarray(rng.uniform(1,2,[4,2]),dtype=dtype)
x1_1 = self.cast_value(x1_1)
x1_2 = numpy.asarray(rng.uniform(1,2,[4,2]),dtype=dtype)
x1_2 = self.cast_value(x1_2)
x2 = numpy.asarray(rng.uniform(1,2,[5,2]),dtype=dtype)
x2 = self.cast_value(x2)
#Test that we can replace with values of the same shape
x1_shared = self.shared_constructor(x1_1)
x1_specify_shape = tensor.specify_shape(x1_shared,
(tensor.as_tensor_variable(x1_1.shape[0]),
x1_shared.shape[1]))
x1_shared.set_value(x1_2)
assert numpy.allclose(self.ref_fct(x1_shared.value), self.ref_fct( x1_2))
shape_op_fct = theano.function([],x1_shared.shape)
topo = shape_op_fct.maker.env.toposort()
if theano.config.mode!='FAST_COMPILE':
assert len(topo)==3
assert isinstance(topo[0].op,tensor.opt.Shape_i)
assert isinstance(topo[1].op,tensor.opt.Shape_i)
assert isinstance(topo[2].op,tensor.opt.MakeVector)
#Test that we forward the input
specify_shape_fct = theano.function([],x1_specify_shape)
theano.printing.debugprint(specify_shape_fct)
assert numpy.all(specify_shape_fct()==x1_2)
topo_specify = specify_shape_fct.maker.env.toposort()
assert len(topo_specify)==6
#Test that we put the shape info into the graph
shape_constant_fct = theano.function([],x1_specify_shape.shape)
theano.printing.debugprint(shape_constant_fct)
assert numpy.all(shape_constant_fct()==shape_op_fct())
topo_cst = shape_constant_fct.maker.env.toposort()
assert len(topo_cst)==6
#Test that we can replace with values of the different shape
# but that will raise an error in some case, but not all
x1_shared.set_value(x2)
self.assertRaises(AssertionError, specify_shape_fct)
#No assertion will be raised as the Op is removed from the graph
shape_constant_fct()
def test_specify_shape_inplace(self): def test_specify_shape_inplace(self):
#test that specify_shape don't break inserting inplace op #test that specify_shape don't break inserting inplace op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论