提交 430f0b6a authored 作者: sentient07's avatar sentient07

Added test for special case of get_scalar_constant_value

上级 0797283e
...@@ -30,7 +30,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d, ...@@ -30,7 +30,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from .type import (GpuArrayType, GpuArrayConstant, get_context, from .type import (GpuArrayType, GpuArrayConstant, get_context,
ContextNotDefined, GpuArraySharedVariable, GpuArrayVariable) ContextNotDefined)
from .basic_ops import (as_gpuarray_variable, infer_context_name, from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu, host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost, HostFromGpu, GpuFromHost,
......
...@@ -7003,6 +7003,9 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -7003,6 +7003,9 @@ class T_get_scalar_constant_value(unittest.TestCase):
assert get_scalar_constant_value(s) == 3 assert get_scalar_constant_value(s) == 3
s = opt.Shape_i(1)(c) s = opt.Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4 assert get_scalar_constant_value(s) == 4
d = theano.tensor.constant(numpy.random.rand(1, 1))
f = theano.tensor.basic.ScalarFromTensor()(opt.Shape_i(0)(d))
assert get_scalar_constant_value(f) == 1
def test_elemwise(self): def test_elemwise(self):
# We test only for a few elemwise, the list of all supported # We test only for a few elemwise, the list of all supported
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论