提交 3360e333 authored 作者: Frederic's avatar Frederic

Fix c code subtensor when the input have 0 dims and the outputs 0 dims.

Added test.
上级 ba9c7fd4
......@@ -3307,8 +3307,25 @@ class Subtensor(Op):
{
%(fail)s;
}
assert (xview->dimensions != %(x)s->dimensions);
assert (xview->strides != %(x)s->strides);
if ((xview->dimensions == %(x)s->dimensions)
&& (%(x)s->dimensions != NULL))
{
PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same dimensions"
" pointors: %%p and %%p",
%(x)s->nd, xview->dimensions, %(x)s->dimensions);
%(fail)s;
}
if (xview->strides == %(x)s->strides
&& (%(x)s->dimensions != NULL))
{
PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same strides"
" pointors: %%p and %%p",
%(x)s->nd, xview->strides, %(x)s->strides);
%(fail)s;
}
for (; outer_ii < %(len_is_slice)s; ++outer_ii)
{
......@@ -3425,7 +3442,7 @@ class Subtensor(Op):
@staticmethod
def helper_c_code_cache_version():
return (2,)
return (3,)
def c_code(self, node, name, inputs, outputs, sub): #DEBUG
part0 = self.helper_c_code(node, name, inputs, outputs, sub,
......@@ -3446,6 +3463,8 @@ class Subtensor(Op):
def c_code_cache_version(self):
hv = self.helper_c_code_cache_version()
if len(hv) == 0:
return ()
return (1, hv)
def R_op(self, inputs, eval_points):
......
......@@ -1957,6 +1957,22 @@ class T_subtensor(unittest.TestCase):
self.assertTrue(tval.shape == (2,))
self.assertTrue(numpy.allclose(tval, n.get_value()[idx]))
def test1_0_dims(self):
n = self.shared(numpy.ones((), dtype=self.dtype))
t = theano.tensor.Subtensor([])(n)
self.assertTrue(isinstance(t.owner.op, Subtensor))
# Silence expected error messages
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
mode = self.mode
self.mode = mode.excluding("local_useless_subtensor")
try:
self.eval_output_and_check(t)
finally:
self.mode = mode
_logger.setLevel(oldlevel)
def test1_err_invalid(self):
n = self.shared(numpy.ones(1, dtype=self.dtype))
try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论