提交 3327d4ad authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Add BinomialTester.

上级 12955fcf
...@@ -689,7 +689,7 @@ class Binomial(gof.op.Op): ...@@ -689,7 +689,7 @@ class Binomial(gof.op.Op):
return None, None, None return None, None, None
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
return [ins_shapes[2]] return [(node.inputs[2][0], node.inputs[2][1])]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -325,6 +325,46 @@ class MultinomialTester(utt.InferShapeTester): ...@@ -325,6 +325,46 @@ class MultinomialTester(utt.InferShapeTester):
self.op_class) self.op_class)
class BinomialTester(utt.InferShapeTester):
n = tensor.scalar()
p = tensor.scalar()
shape = tensor.lvector()
_n = 5
_p = .25
_shape = np.asarray([3, 5], dtype='int64')
inputs = [n, p, shape]
_inputs = [_n, _p, _shape]
def setUp(self):
super(BinomialTester, self).setUp()
self.op_class = S2.Binomial
def test_op(self):
for sp_format in sparse.sparse_formats:
for o_type in sparse.float_dtypes:
f = theano.function(
self.inputs,
S2.Binomial(sp_format, o_type)(*self.inputs))
tested = f(*self._inputs)
assert tested.shape == tuple(self._shape)
assert tested.format == sp_format
assert tested.dtype == o_type
assert np.allclose(np.floor(tested.todense()),
tested.todense())
def test_infer_shape(self):
for sp_format in sparse.sparse_formats:
for o_type in sparse.float_dtypes:
self._compile_and_check(
self.inputs,
[S2.Binomial(sp_format, o_type)(*self.inputs)],
self._inputs,
self.op_class)
class _StructuredMonoidUnaryTester(unittest.TestCase): class _StructuredMonoidUnaryTester(unittest.TestCase):
def test_op(self): def test_op(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
...@@ -472,7 +512,6 @@ class SamplingDotTester(utt.InferShapeTester): ...@@ -472,7 +512,6 @@ class SamplingDotTester(utt.InferShapeTester):
for maximum in [5, 5, 2]] for maximum in [5, 5, 2]]
a[2] = sp.csr_matrix(a[2]) a[2] = sp.csr_matrix(a[2])
def setUp(self): def setUp(self):
super(SamplingDotTester, self).setUp() super(SamplingDotTester, self).setUp()
self.op_class = S2.SamplingDot self.op_class = S2.SamplingDot
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论