提交 f88a7b18 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Typo fix in message and minor code simplification

上级 f6de708b
...@@ -521,14 +521,13 @@ class Rebroadcast(gof.Op): ...@@ -521,14 +521,13 @@ class Rebroadcast(gof.Op):
self.axis = dict(axis) self.axis = dict(axis)
for axis, broad in self.axis.iteritems(): for axis, broad in self.axis.iteritems():
assert isinstance(axis, (numpy.integer, int)), ( assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast need integers axis. Got ", axis) "Rebroadcast needs integer axes. Got ", axis)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis return type(self) == type(other) and self.axis == other.axis
def __hash__(self): def __hash__(self):
items = self.axis.items() items = sorted(self.axis.iteritems()) # no ambiguity because each item key is unique
items.sort() # no ambiguity because each item key is unique
return hash(type(self)) ^ hash(tuple(items)) return hash(type(self)) ^ hash(tuple(items))
def __str__(self): def __str__(self):
...@@ -536,7 +535,7 @@ class Rebroadcast(gof.Op): ...@@ -536,7 +535,7 @@ class Rebroadcast(gof.Op):
broadcast_pattern = [] broadcast_pattern = []
else: else:
broadcast_pattern = ['?' for i broadcast_pattern = ['?' for i
in xrange(1 + numpy.max(self.axis.keys()))] in xrange(1 + max(self.axis))]
for k, v in self.axis.iteritems(): for k, v in self.axis.iteritems():
broadcast_pattern[k] = str(int(v)) broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, return '%s{%s}' % (self.__class__.__name__,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论