提交 c09be4a5 authored 作者: Frederic Bastien's avatar Frederic Bastien

small bugfix.

上级 00d10a13
...@@ -22,13 +22,13 @@ class LogisticRegressionN(module.FancyModule): ...@@ -22,13 +22,13 @@ class LogisticRegressionN(module.FancyModule):
self.__hide__ = ['params'] self.__hide__ = ['params']
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2): if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2):
raise NotImplemented raise NotImplementedError
#we compare the member. #we compare the member.
if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr: if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr:
return True return True
return False return False
def __hash__(self): def __hash__(self):
raise NotImplemented raise NotImplementedError
def __init__(self, x = None, targ = None): def __init__(self, x = None, targ = None):
super(LogisticRegressionN, self).__init__() #boilerplate super(LogisticRegressionN, self).__init__() #boilerplate
...@@ -69,13 +69,13 @@ class LogisticRegression2(module.FancyModule): ...@@ -69,13 +69,13 @@ class LogisticRegression2(module.FancyModule):
self.__hide__ = ['params'] self.__hide__ = ['params']
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2): if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2):
raise NotImplemented raise NotImplementedError
#we compare the member. #we compare the member.
if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr: if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr:
return True return True
return False return False
def __hash__(self): def __hash__(self):
raise NotImplemented raise NotImplementedError
def __init__(self, x = None, targ = None): def __init__(self, x = None, targ = None):
super(LogisticRegression2, self).__init__() #boilerplate super(LogisticRegression2, self).__init__() #boilerplate
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论