提交 559e603d authored 作者: fsavard's avatar fsavard

Added __setstate__ in Prod to update old pickled Prod objects to have no_zeros_in_input attribute

上级 e8c50c78
......@@ -1143,6 +1143,11 @@ class Prod(CAReduce):
self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct):
self.__dict__.update(dct)
if 'no_zeros_in_input' not in dct:
self.no_zeros_in_input = False
def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis and self.no_zeros_in_input == other.no_zeros_in_input
......@@ -1207,7 +1212,6 @@ class Prod(CAReduce):
prod_without_zeros = ProdWithoutZeros(axis=self.axis)(prod_without_zeros_in)
prod_without_zeros = prod_without_zeros.dimshuffle(new_dims)
groups_without_zeros = T.eq(sum_where_zeros, 0.0).dimshuffle(new_dims)
final_grad = T.switch(groups_without_zeros, grad_case_without_zeros,
T.switch(where_single_zero, prod_without_zeros, 0.0) * gz)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论