提交 dcf27611 authored 作者: Frederic Bastien's avatar Frederic Bastien

added __eq__ fct to compare different module.

上级 d5841f8d
......@@ -20,15 +20,15 @@ class LogisticRegressionN(module.FancyModule):
self.b = rng.randn(n_out)
self.lr = 0.01
self.__hide__ = ['params']
# def __eq__(self, other):
# if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2):
# raise NotImplemented
# #we compare the member.
# if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.stepsize == other.stepsize:
# return True
# return False
# def __hash__(self):
# raise NotImplemented
def __eq__(self, other):
if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2):
raise NotImplemented
#we compare the member.
if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr:
return True
return False
def __hash__(self):
raise NotImplemented
def __init__(self, x = None, targ = None):
super(LogisticRegressionN, self).__init__() #boilerplate
......@@ -67,7 +67,16 @@ class LogisticRegression2(module.FancyModule):
self.b = rng.randn(1)
self.lr = 0.01
self.__hide__ = ['params']
def __eq__(self, other):
if not isinstance(other.component, LogisticRegressionN) and not isinstance(other.component, LogisticRegression2):
raise NotImplemented
#we compare the member.
if (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and self.lr == other.lr:
return True
return False
def __hash__(self):
raise NotImplemented
def __init__(self, x = None, targ = None):
super(LogisticRegression2, self).__init__() #boilerplate
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论