提交 51046423 authored 作者: Frederic's avatar Frederic

Fix test of shape_of_variables and raise a good error to explain that the…

Fix test of shape_of_variables and raise a good error to explain that the FunctionGraph interface changed.
上级 41adddb7
import unittest
import numpy import numpy
import theano import theano
...@@ -50,23 +52,30 @@ def test_hash_from_dict(): ...@@ -50,23 +52,30 @@ def test_hash_from_dict():
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]}) assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
def test_shape_of_variables_simple(): class Tshape_of_variables(unittest.TestCase):
x = theano.tensor.matrix('x') def test_simple(self):
y = x+x x = theano.tensor.matrix('x')
fgraph = theano.FunctionGraph([x], [y]) y = x+x
assert shape_of_variables(fgraph, {x: (5, 5)}) == {x: (5, 5), y: (5, 5)} fgraph = theano.FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (5, 5)})
assert shapes == {x: (5, 5), y: (5, 5)}
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x.T) y = theano.tensor.dot(x, x.T)
fgraph = theano.FunctionGraph([x], [y]) fgraph = theano.FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (5, 1)}) shapes = shape_of_variables(fgraph, {x: (5, 1)})
assert shapes[x] == (5, 1) assert shapes[x] == (5, 1)
assert shapes[y] == (5, 5) assert shapes[y] == (5, 5)
def test_subtensor(self):
x = theano.tensor.matrix('x')
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx], clone=False)
shapes = shape_of_variables(fgraph, {x: (10, 10)})
assert shapes[subx] == (9, 10)
def test_shape_of_variables_subtensor(): def test_err(self):
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
subx = x[1:] subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx]) fgraph = theano.FunctionGraph([x], [subx])
shapes = shape_of_variables(fgraph, {x: (10, 10)}) self.assertRaises(ValueError, shape_of_variables, fgraph, {x: (10, 10)})
assert shapes[subx] == (9, 10)
import numpy import numpy
import theano import theano
from theano.compat.python2x import any
from theano.gof.cc import hash_from_code from theano.gof.cc import hash_from_code
...@@ -69,7 +70,7 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -69,7 +70,7 @@ def shape_of_variables(fgraph, input_shapes):
>>> import theano >>> import theano
>>> x = theano.tensor.matrix('x') >>> x = theano.tensor.matrix('x')
>>> y = x[512:]; y.name = 'y' >>> y = x[512:]; y.name = 'y'
>>> fgraph = theano.FunctionGraph([x], [y]) >>> fgraph = theano.FunctionGraph([x], [y], clone=False)
>>> shape_of_variables(fgraph, {x: (1024, 1024)}) >>> shape_of_variables(fgraph, {x: (1024, 1024)})
{y: (512, 1024), x: (1024, 1024)} {y: (512, 1024), x: (1024, 1024)}
""" """
...@@ -85,6 +86,12 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -85,6 +86,12 @@ def shape_of_variables(fgraph, input_shapes):
compute_shapes = theano.function(input_dims, output_dims) compute_shapes = theano.function(input_dims, output_dims)
if any([i not in fgraph.inputs for i in input_shapes.keys()]):
raise ValueError(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clone the graph it receive."
" To have the old behavior, give him this new parameter `clone=False`.")
numeric_input_dims = [dim for inp in fgraph.inputs numeric_input_dims = [dim for inp in fgraph.inputs
for dim in input_shapes[inp]] for dim in input_shapes[inp]]
numeric_output_dims = compute_shapes(*numeric_input_dims) numeric_output_dims = compute_shapes(*numeric_input_dims)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论