Commit 0fb5fe65, authored by Kyle Meyer

Move values_eq_approx_remove* methods to functions

Move values_eq_approx_remove_inf, values_eq_approx_remove_nan, and values_eq_approx_remove_inf_nan static methods to top-level functions to fix pickling error that occurred during parallel sampling with PyMC (https://github.com/pymc-devs/pymc3/issues/660).
Parent e67a5ba0
...@@ -24,6 +24,7 @@ from theano.gof import Apply ...@@ -24,6 +24,7 @@ from theano.gof import Apply
from theano.tensor.nnet.sigm import sigmoid, softplus from theano.tensor.nnet.sigm import sigmoid, softplus
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented
from theano.tensor.type import values_eq_approx_remove_nan
############ ############
...@@ -1965,7 +1966,7 @@ def make_out_pattern(X): ...@@ -1965,7 +1966,7 @@ def make_out_pattern(X):
out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum( out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(
axis=1)).dimshuffle(0, 'x') axis=1)).dimshuffle(0, 'x')
#tell DEBUG_MODE that it's OK if the original graph produced NaN and the optimized graph does not #tell DEBUG_MODE that it's OK if the original graph produced NaN and the optimized graph does not
out_var.values_eq_approx = out_var.type.values_eq_approx_remove_nan out_var.values_eq_approx = values_eq_approx_remove_nan
return out_var return out_var
......
...@@ -35,6 +35,9 @@ from theano import scalar ...@@ -35,6 +35,9 @@ from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
from theano.compile.ops import Shape_i from theano.compile.ops import Shape_i
from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
values_eq_approx_remove_inf_nan)
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.gof.opt import (Optimizer, pre_constant_merge, from theano.gof.opt import (Optimizer, pre_constant_merge,
...@@ -2883,8 +2886,7 @@ def local_mul_switch_sink(node): ...@@ -2883,8 +2886,7 @@ def local_mul_switch_sink(node):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
T.mul(*(listmul + [switch.inputs[2]])))] T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = values_eq_approx_remove_nan
0].type.values_eq_approx_remove_nan
return fct return fct
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2894,8 +2896,7 @@ def local_mul_switch_sink(node): ...@@ -2894,8 +2896,7 @@ def local_mul_switch_sink(node):
listmul = node.inputs[:idx] + node.inputs[idx + 1:] listmul = node.inputs[:idx] + node.inputs[idx + 1:]
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
T.mul(*(listmul + [switch.inputs[1]])), 0)] T.mul(*(listmul + [switch.inputs[1]])), 0)]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = values_eq_approx_remove_nan
0].type.values_eq_approx_remove_nan
return fct return fct
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2925,8 +2926,7 @@ def local_div_switch_sink(node): ...@@ -2925,8 +2926,7 @@ def local_div_switch_sink(node):
if get_scalar_constant_value(switch.inputs[1]) == 0.: if get_scalar_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0], 0, fct = [T.switch(switch.inputs[0], 0,
op(switch.inputs[2], node.inputs[1]))] op(switch.inputs[2], node.inputs[1]))]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = values_eq_approx_remove_nan
0].type.values_eq_approx_remove_nan
return fct return fct
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2934,8 +2934,7 @@ def local_div_switch_sink(node): ...@@ -2934,8 +2934,7 @@ def local_div_switch_sink(node):
if get_scalar_constant_value(switch.inputs[2]) == 0.: if get_scalar_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0], fct = [T.switch(switch.inputs[0],
op(switch.inputs[1], node.inputs[1]), 0)] op(switch.inputs[1], node.inputs[1]), 0)]
fct[0].values_eq_approx = fct[ fct[0].values_eq_approx = values_eq_approx_remove_nan
0].type.values_eq_approx_remove_nan
return fct return fct
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -4474,7 +4473,7 @@ def local_log_add(node): ...@@ -4474,7 +4473,7 @@ def local_log_add(node):
ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre
for p in pre_exp]))) for p in pre_exp])))
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf ret.values_eq_approx = values_eq_approx_remove_inf
return [ret] return [ret]
...@@ -4899,7 +4898,7 @@ def local_log_erfc(node): ...@@ -4899,7 +4898,7 @@ def local_log_erfc(node):
threshold = 26.641747557 threshold = 26.641747557
ret = T.switch(x < threshold, node.outputs[0], stab_value) ret = T.switch(x < threshold, node.outputs[0], stab_value)
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf ret.values_eq_approx = values_eq_approx_remove_inf
return [ret] return [ret]
...@@ -5046,7 +5045,7 @@ def local_grad_log_erfc_neg(node): ...@@ -5046,7 +5045,7 @@ def local_grad_log_erfc_neg(node):
elif x.dtype == 'float64': elif x.dtype == 'float64':
threshold = 26.641747557 threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf_nan ret.values_eq_approx = values_eq_approx_remove_inf_nan
return [ret] return [ret]
""" """
......
...@@ -378,18 +378,6 @@ class TensorType(Type): ...@@ -378,18 +378,6 @@ class TensorType(Type):
return False return False
@staticmethod
def values_eq_approx_remove_inf(a, b):
    """Approximate equality of `a` and `b`, presumably treating inf
    positions as equal (the third positional flag to values_eq_approx;
    confirm against values_eq_approx's signature).

    NOTE(review): this commit removes the staticmethod in favour of a
    module-level function of the same name, because the staticmethod
    reference did not pickle during parallel sampling (see commit
    message, pymc3 issue #660).
    """
    return TensorType.values_eq_approx(a, b, True)
@staticmethod
def values_eq_approx_remove_nan(a, b):
    """Approximate equality of `a` and `b`, presumably treating NaN
    positions as equal (the fourth positional flag to values_eq_approx;
    confirm against values_eq_approx's signature).

    NOTE(review): this commit removes the staticmethod in favour of a
    module-level function of the same name, to fix a pickling error
    during parallel sampling (see commit message, pymc3 issue #660).
    """
    return TensorType.values_eq_approx(a, b, False, True)
@staticmethod
def values_eq_approx_remove_inf_nan(a, b):
    """Approximate equality of `a` and `b`, presumably treating both inf
    and NaN positions as equal (third and fourth positional flags to
    values_eq_approx; confirm against its signature).

    NOTE(review): this commit removes the staticmethod in favour of a
    module-level function of the same name, to fix a pickling error
    during parallel sampling (see commit message, pymc3 issue #660).
    """
    return TensorType.values_eq_approx(a, b, True, True)
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable) return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
...@@ -629,6 +617,19 @@ class TensorType(Type): ...@@ -629,6 +617,19 @@ class TensorType(Type):
return numpy.dtype(self.dtype).itemsize return numpy.dtype(self.dtype).itemsize
theano.compile.ops.expandable_types += (TensorType,) theano.compile.ops.expandable_types += (TensorType,)
def values_eq_approx_remove_inf(a, b):
    """Compare `a` and `b` with ``TensorType.values_eq_approx``, enabling
    the third positional flag (which, per the name, makes inf values
    compare as equal — confirm against values_eq_approx's signature).

    Defined at module level rather than as a staticmethod so that graphs
    referencing it pickle cleanly (see pymc3 issue #660).
    """
    ignore_inf = True
    return TensorType.values_eq_approx(a, b, ignore_inf)
def values_eq_approx_remove_nan(a, b):
    """Compare `a` and `b` with ``TensorType.values_eq_approx``, enabling
    the fourth positional flag (which, per the name, makes NaN values
    compare as equal — confirm against values_eq_approx's signature).

    Defined at module level rather than as a staticmethod so that graphs
    referencing it pickle cleanly (see pymc3 issue #660).
    """
    ignore_inf = False
    ignore_nan = True
    return TensorType.values_eq_approx(a, b, ignore_inf, ignore_nan)
def values_eq_approx_remove_inf_nan(a, b):
    """Compare `a` and `b` with ``TensorType.values_eq_approx``, enabling
    both the third and fourth positional flags (which, per the name, make
    inf and NaN values compare as equal — confirm against
    values_eq_approx's signature).

    Defined at module level rather than as a staticmethod so that graphs
    referencing it pickle cleanly (see pymc3 issue #660).
    """
    ignore_inf = True
    ignore_nan = True
    return TensorType.values_eq_approx(a, b, ignore_inf, ignore_nan)
# Register TensorType C code for ViewOp. # Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
TensorType, TensorType,
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment