提交 de93a12f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2536 from kyleam/fix-pymc-660

Move values_eq_approx_remove* methods to functions
......@@ -24,6 +24,7 @@ from theano.gof import Apply
from theano.tensor.nnet.sigm import sigmoid, softplus
from theano.gradient import DisconnectedType
from theano.gradient import grad_not_implemented
from theano.tensor.type import values_eq_approx_remove_nan
############
......@@ -1965,7 +1966,7 @@ def make_out_pattern(X):
out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(
axis=1)).dimshuffle(0, 'x')
#tell DEBUG_MODE that it's OK if the original graph produced NaN and the optimized graph does not
out_var.values_eq_approx = out_var.type.values_eq_approx_remove_nan
out_var.values_eq_approx = values_eq_approx_remove_nan
return out_var
......
......@@ -35,6 +35,9 @@ from theano import scalar
from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file
from theano.compile.ops import Shape_i
from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
values_eq_approx_remove_inf_nan)
from theano.gof.python25 import any, all
from theano.gof.opt import (Optimizer, pre_constant_merge,
......@@ -2883,8 +2886,7 @@ def local_mul_switch_sink(node):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
fct[0].values_eq_approx = values_eq_approx_remove_nan
return fct
except NotScalarConstantError:
pass
......@@ -2894,8 +2896,7 @@ def local_mul_switch_sink(node):
listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
fct[0].values_eq_approx = values_eq_approx_remove_nan
return fct
except NotScalarConstantError:
pass
......@@ -2925,8 +2926,7 @@ def local_div_switch_sink(node):
if get_scalar_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
fct[0].values_eq_approx = values_eq_approx_remove_nan
return fct
except NotScalarConstantError:
pass
......@@ -2934,8 +2934,7 @@ def local_div_switch_sink(node):
if get_scalar_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)]
fct[0].values_eq_approx = fct[
0].type.values_eq_approx_remove_nan
fct[0].values_eq_approx = values_eq_approx_remove_nan
return fct
except NotScalarConstantError:
pass
......@@ -4474,7 +4473,7 @@ def local_log_add(node):
ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre
for p in pre_exp])))
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf
ret.values_eq_approx = values_eq_approx_remove_inf
return [ret]
......@@ -4899,7 +4898,7 @@ def local_log_erfc(node):
threshold = 26.641747557
ret = T.switch(x < threshold, node.outputs[0], stab_value)
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf
ret.values_eq_approx = values_eq_approx_remove_inf
return [ret]
......@@ -5046,7 +5045,7 @@ def local_grad_log_erfc_neg(node):
elif x.dtype == 'float64':
threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf_nan
ret.values_eq_approx = values_eq_approx_remove_inf_nan
return [ret]
"""
......
......@@ -378,18 +378,6 @@ class TensorType(Type):
return False
@staticmethod
def values_eq_approx_remove_inf(a, b):
return TensorType.values_eq_approx(a, b, True)
@staticmethod
def values_eq_approx_remove_nan(a, b):
return TensorType.values_eq_approx(a, b, False, True)
@staticmethod
def values_eq_approx_remove_inf_nan(a, b):
return TensorType.values_eq_approx(a, b, True, True)
def __hash__(self):
"""Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
......@@ -629,6 +617,19 @@ class TensorType(Type):
return numpy.dtype(self.dtype).itemsize
theano.compile.ops.expandable_types += (TensorType,)
def values_eq_approx_remove_inf(a, b):
return TensorType.values_eq_approx(a, b, True)
def values_eq_approx_remove_nan(a, b):
return TensorType.values_eq_approx(a, b, False, True)
def values_eq_approx_remove_inf_nan(a, b):
return TensorType.values_eq_approx(a, b, True, True)
# Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code(
TensorType,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论