提交 3360e333 authored 作者: Frederic's avatar Frederic

Fix c code subtensor when the input have 0 dims and the outputs 0 dims.

Added test.
上级 ba9c7fd4
...@@ -3307,8 +3307,25 @@ class Subtensor(Op): ...@@ -3307,8 +3307,25 @@ class Subtensor(Op):
{ {
%(fail)s; %(fail)s;
} }
assert (xview->dimensions != %(x)s->dimensions);
assert (xview->strides != %(x)s->strides); if ((xview->dimensions == %(x)s->dimensions)
&& (%(x)s->dimensions != NULL))
{
PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same dimensions"
" pointors: %%p and %%p",
%(x)s->nd, xview->dimensions, %(x)s->dimensions);
%(fail)s;
}
if (xview->strides == %(x)s->strides
&& (%(x)s->dimensions != NULL))
{
PyErr_Format(PyExc_ValueError, "x and xview"
"(with %%d dims) have the same strides"
" pointors: %%p and %%p",
%(x)s->nd, xview->strides, %(x)s->strides);
%(fail)s;
}
for (; outer_ii < %(len_is_slice)s; ++outer_ii) for (; outer_ii < %(len_is_slice)s; ++outer_ii)
{ {
...@@ -3425,7 +3442,7 @@ class Subtensor(Op): ...@@ -3425,7 +3442,7 @@ class Subtensor(Op):
@staticmethod @staticmethod
def helper_c_code_cache_version(): def helper_c_code_cache_version():
return (2,) return (3,)
def c_code(self, node, name, inputs, outputs, sub): #DEBUG def c_code(self, node, name, inputs, outputs, sub): #DEBUG
part0 = self.helper_c_code(node, name, inputs, outputs, sub, part0 = self.helper_c_code(node, name, inputs, outputs, sub,
...@@ -3446,6 +3463,8 @@ class Subtensor(Op): ...@@ -3446,6 +3463,8 @@ class Subtensor(Op):
def c_code_cache_version(self): def c_code_cache_version(self):
hv = self.helper_c_code_cache_version() hv = self.helper_c_code_cache_version()
if len(hv) == 0:
return ()
return (1, hv) return (1, hv)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -1957,6 +1957,22 @@ class T_subtensor(unittest.TestCase): ...@@ -1957,6 +1957,22 @@ class T_subtensor(unittest.TestCase):
self.assertTrue(tval.shape == (2,)) self.assertTrue(tval.shape == (2,))
self.assertTrue(numpy.allclose(tval, n.get_value()[idx])) self.assertTrue(numpy.allclose(tval, n.get_value()[idx]))
def test1_0_dims(self):
n = self.shared(numpy.ones((), dtype=self.dtype))
t = theano.tensor.Subtensor([])(n)
self.assertTrue(isinstance(t.owner.op, Subtensor))
# Silence expected error messages
_logger = logging.getLogger('theano.gof.opt')
oldlevel = _logger.level
_logger.setLevel(logging.CRITICAL)
mode = self.mode
self.mode = mode.excluding("local_useless_subtensor")
try:
self.eval_output_and_check(t)
finally:
self.mode = mode
_logger.setLevel(oldlevel)
def test1_err_invalid(self): def test1_err_invalid(self):
n = self.shared(numpy.ones(1, dtype=self.dtype)) n = self.shared(numpy.ones(1, dtype=self.dtype))
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论