提交 c4b04aeb authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix chi2sf, it always return float64!

上级 f04d5cee
...@@ -854,6 +854,14 @@ def upgrade_to_float(*types): ...@@ -854,6 +854,14 @@ def upgrade_to_float(*types):
for type in types])), for type in types])),
def upgrade_to_float64(*types):
"""
Upgrade any int and float32 to float64 to do as SciPy.
"""
return get_scalar_type('float64'),
def same_out(type): def same_out(type):
return type, return type,
......
...@@ -6,6 +6,7 @@ import numpy ...@@ -6,6 +6,7 @@ import numpy
import theano import theano
from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp, from theano.scalar.basic import (UnaryScalarOp, BinaryScalarOp,
exp, upgrade_to_float, exp, upgrade_to_float,
upgrade_to_float64,
float_types) float_types)
from theano.scalar.basic import (upgrade_to_float_no_complex, from theano.scalar.basic import (upgrade_to_float_no_complex,
complex_types, discrete_types, complex_types, discrete_types,
...@@ -369,7 +370,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -369,7 +370,7 @@ class Chi2SF(BinaryScalarOp):
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
else: else:
super(Chi2SF, self).impl(x, k) super(Chi2SF, self).impl(x, k)
chi2sf = Chi2SF(upgrade_to_float, name='chi2sf') chi2sf = Chi2SF(upgrade_to_float64, name='chi2sf')
class J1(UnaryScalarOp): class J1(UnaryScalarOp):
......
...@@ -1718,7 +1718,7 @@ if imported_scipy_special: ...@@ -1718,7 +1718,7 @@ if imported_scipy_special:
expected_gamma = scipy.special.gamma expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi expected_psi = scipy.special.psi
expected_chi2sf = lambda x, df: scipy.stats.chi2.sf(x, df).astype(x.dtype) expected_chi2sf = scipy.stats.chi2.sf
expected_j0 = scipy.special.j0 expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1 expected_j1 = scipy.special.j1
skip_scipy = False skip_scipy = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论