提交 fd3550ae authored 作者: Frederic Bastien's avatar Frederic Bastien

Reuse a constant one to make optimization faster(less constant merge).

上级 c00e1a1d
...@@ -491,6 +491,8 @@ class Shape_i(T.Op): ...@@ -491,6 +491,8 @@ class Shape_i(T.Op):
def grad(self, inp, grads): def grad(self, inp, grads):
return [None] return [None]
lscalar_one = T.constant(1, dtype='int64')
assert lscalar_one.type == T.lscalar
class ShapeFeature(object): class ShapeFeature(object):
"""Graph optimizer for removing all calls to shape() """Graph optimizer for removing all calls to shape()
...@@ -571,7 +573,7 @@ class ShapeFeature(object): ...@@ -571,7 +573,7 @@ class ShapeFeature(object):
#TODO: Write a doc string for this method #TODO: Write a doc string for this method
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]: if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
return T.constant(1.0,dtype='int64') return lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] return Shape_i(i).make_node(r).outputs[0]
...@@ -593,7 +595,7 @@ class ShapeFeature(object): ...@@ -593,7 +595,7 @@ class ShapeFeature(object):
assert s_i is not None assert s_i is not None
if s_i == 1: if s_i == 1:
# don't make the optimizer merge a zillion ones together # don't make the optimizer merge a zillion ones together
return self.lscalar_one return lscalar_one
if type(s_i) is int or isinstance(s_i, numpy.integer): if type(s_i) is int or isinstance(s_i, numpy.integer):
# this shape is a constant # this shape is a constant
assert s_i >= 0 assert s_i >= 0
...@@ -693,8 +695,6 @@ class ShapeFeature(object): ...@@ -693,8 +695,6 @@ class ShapeFeature(object):
env.shape_feature = self env.shape_feature = self
self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple) self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.scheduled = {} # Variable -> self.scheduled = {} # Variable ->
self.lscalar_one = T.constant(1, dtype='int64')
assert self.lscalar_one.type == T.lscalar
for node in env.toposort(): for node in env.toposort():
self.on_import(env, node) self.on_import(env, node)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论