提交 193c9979 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

assert_conv_shape resolves constants.

上级 beded2b3
...@@ -432,6 +432,9 @@ def check_conv_gradinputs_shape(image_shape, kernel_shape, output_shape, ...@@ -432,6 +432,9 @@ def check_conv_gradinputs_shape(image_shape, kernel_shape, output_shape,
def assert_conv_shape(shape): def assert_conv_shape(shape):
"""This function adds Assert nodes that check if shape is a valid convolution shape. """This function adds Assert nodes that check if shape is a valid convolution shape.
The first two dimensions should be larger than or equal to zero. The convolution
dimensions should be larger than zero.
Parameters Parameters
---------- ----------
shape: tuple of int (symbolic or numeric) corresponding to the input, output or shape: tuple of int (symbolic or numeric) corresponding to the input, output or
...@@ -442,12 +445,23 @@ def assert_conv_shape(shape): ...@@ -442,12 +445,23 @@ def assert_conv_shape(shape):
Returns Returns
------- -------
Returns a tuple similar to the given `shape`, but with each element wrapped in Returns a tuple similar to the given `shape`. For constant elements in `shape`,
an `Assert` op that checks that dimension. The first two dimensions should be the function checks the value and raises a `ValueError` if the dimension is invalid.
larger than or equal to zero. The convolution dimensions should be larger than zero. The elements that are not constant are wrapped in an `Assert` op that checks the
dimension at run time.
""" """
out_shape = [] out_shape = []
for i, n in enumerate(shape): for i, n in enumerate(shape):
try:
const_n = get_scalar_constant_value(n)
if i < 2:
if const_n < 0:
raise ValueError('The convolution would produce an invalid shape (dim[%d]: %d < 0).' % (i, const_n))
else:
if const_n <= 0:
raise ValueError('The convolution would produce an invalid shape (dim[%d]: %d <= 0).' % (i, const_n))
out_shape.append(n)
except NotScalarConstantError:
if i < 2: if i < 2:
assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] < 0).' % i) assert_shp = Assert('The convolution would produce an invalid shape (dim[%d] < 0).' % i)
out_shape.append(assert_shp(n, theano.tensor.ge(n, 0))) out_shape.append(assert_shp(n, theano.tensor.ge(n, 0)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论