提交 193c9979 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

assert_conv_shape resolves constants.

上级 beded2b3
......@@ -432,6 +432,9 @@ def check_conv_gradinputs_shape(image_shape, kernel_shape, output_shape,
def assert_conv_shape(shape):
"""This function adds Assert nodes that check if shape is a valid convolution shape.
The first two dimensions should be larger than or equal to zero. The convolution
dimensions should be larger than zero.
Parameters
----------
shape: tuple of int (symbolic or numeric) corresponding to the input, output or
......@@ -442,18 +445,29 @@ def assert_conv_shape(shape):
Returns
-------
Returns a tuple similar to the given `shape`, but with each element wrapped in
an `Assert` op that checks that dimension. The first two dimensions should be
larger than or equal to zero. The convolution dimensions should be larger than zero.
Returns a tuple similar to the given `shape`. For constant elements in `shape`,
the function checks the value and raises a `ValueError` if the dimension is invalid.
The elements that are not constant are wrapped in an `Assert` op that checks the
dimension at run time.
"""
out_shape = []
for i, n in enumerate(shape):
if i < 2:
assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] < 0).' % i)
out_shape.append(assert_shp(n, theano.tensor.ge(n, 0)))
else:
assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] <= 0).' % i)
out_shape.append(assert_shp(n, theano.tensor.gt(n, 0)))
try:
const_n = get_scalar_constant_value(n)
if i < 2:
if const_n < 0:
raise ValueError('The convolution would produce an invalid shape (dim[%d]: %d < 0).' % (i, const_n))
else:
if const_n <= 0:
raise ValueError('The convolution would produce an invalid shape (dim[%d]: %d <= 0).' % (i, const_n))
out_shape.append(n)
except NotScalarConstantError:
if i < 2:
assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] < 0).' % i)
out_shape.append(assert_shp(n, theano.tensor.ge(n, 0)))
else:
assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] <= 0).' % i)
out_shape.append(assert_shp(n, theano.tensor.gt(n, 0)))
return tuple(out_shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论