Unverified 提交 450dd859 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Make Canonizer always collapse nested ops (#4)

Closes #6685.
上级 149976c2
...@@ -4992,8 +4992,18 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4992,8 +4992,18 @@ class Canonizer(gof.LocalOptimizer):
return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in
zip(x, y)) zip(x, y))
if same(orig_num, num) and same(orig_denum, denum): if (same(orig_num, num) and same(orig_denum, denum) and
# We return False if there are no changes # Check to see if we've collapsed some nested ops.
not (len(orig_denum) == 0 and
# Make sure this change would increase the number of vector
# arguments--decreasing the number of unnecessary `self.main`
# nodes.
len(node.inputs) < len(orig_num)) and
# Do a similar check for the reciprocal op.
not (self.use_reciprocal and
node.op == self.reciprocal and
len(orig_num) == 0 and node.inputs[0].owner and
len(node.inputs[0].owner.inputs) < len(orig_denum))):
return False return False
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
......
...@@ -367,9 +367,9 @@ class test_canonize(unittest.TestCase): ...@@ -367,9 +367,9 @@ class test_canonize(unittest.TestCase):
cases = [ cases = [
(fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'),
(fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'),
# (fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), (fx + fy + fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
# (dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'), # (dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
# (fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'), (fx * fy * fz, (fx, fy, fz), (fxv, fyv, fzv), 1, 'float32'),
# (dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'), # (dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
# (fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'), # (fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
# (dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'), # (dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
...@@ -4988,7 +4988,7 @@ class T_local_erfc(unittest.TestCase): ...@@ -4988,7 +4988,7 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode) f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论