提交 51046423 authored 作者: Frederic's avatar Frederic

Fix test of shape_of_variables and raise a good error to explain that the…

Fix test of shape_of_variables and raise a good error to explain that the FunctionGraph interface changed.
上级 41adddb7
import unittest
import numpy
import theano
......@@ -50,23 +52,30 @@ def test_hash_from_dict():
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
def test_shape_of_variables_simple():
x = theano.tensor.matrix('x')
y = x+x
fgraph = theano.FunctionGraph([x], [y])
assert shape_of_variables(fgraph, {x: (5, 5)}) == {x: (5, 5), y: (5, 5)}
class Tshape_of_variables(unittest.TestCase):
def test_simple(self):
x = theano.tensor.matrix('x')
y = x+x
fgraph = theano.FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (5, 5)})
assert shapes == {x: (5, 5), y: (5, 5)}
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x.T)
fgraph = theano.FunctionGraph([x], [y])
shapes = shape_of_variables(fgraph, {x: (5, 1)})
assert shapes[x] == (5, 1)
assert shapes[y] == (5, 5)
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x.T)
fgraph = theano.FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (5, 1)})
assert shapes[x] == (5, 1)
assert shapes[y] == (5, 5)
def test_subtensor(self):
x = theano.tensor.matrix('x')
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx], clone=False)
shapes = shape_of_variables(fgraph, {x: (10, 10)})
assert shapes[subx] == (9, 10)
def test_shape_of_variables_subtensor():
x = theano.tensor.matrix('x')
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx])
shapes = shape_of_variables(fgraph, {x: (10, 10)})
assert shapes[subx] == (9, 10)
def test_err(self):
x = theano.tensor.matrix('x')
subx = x[1:]
fgraph = theano.FunctionGraph([x], [subx])
self.assertRaises(ValueError, shape_of_variables, fgraph, {x: (10, 10)})
import numpy
import theano
from theano.compat.python2x import any
from theano.gof.cc import hash_from_code
......@@ -69,7 +70,7 @@ def shape_of_variables(fgraph, input_shapes):
>>> import theano
>>> x = theano.tensor.matrix('x')
>>> y = x[512:]; y.name = 'y'
>>> fgraph = theano.FunctionGraph([x], [y])
>>> fgraph = theano.FunctionGraph([x], [y], clone=False)
>>> shape_of_variables(fgraph, {x: (1024, 1024)})
{y: (512, 1024), x: (1024, 1024)}
"""
......@@ -85,6 +86,12 @@ def shape_of_variables(fgraph, input_shapes):
compute_shapes = theano.function(input_dims, output_dims)
if any([i not in fgraph.inputs for i in input_shapes.keys()]):
raise ValueError(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clone the graph it receive."
" To have the old behavior, give him this new parameter `clone=False`.")
numeric_input_dims = [dim for inp in fgraph.inputs
for dim in input_shapes[inp]]
numeric_output_dims = compute_shapes(*numeric_input_dims)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论