提交 d1c51b24 authored 作者: Frederic Bastien's avatar Frederic Bastien

In ShapeFeature, use a different variable for each env...

上级 cd2fdd51
...@@ -491,8 +491,6 @@ class Shape_i(T.Op): ...@@ -491,8 +491,6 @@ class Shape_i(T.Op):
def grad(self, inp, grads): def grad(self, inp, grads):
return [None] return [None]
lscalar_one = T.constant(1, dtype='int64')
assert lscalar_one.type == T.lscalar
class ShapeFeature(object): class ShapeFeature(object):
"""Graph optimizer for removing all calls to shape() """Graph optimizer for removing all calls to shape()
...@@ -568,12 +566,12 @@ class ShapeFeature(object): ...@@ -568,12 +566,12 @@ class ShapeFeature(object):
sometimes Theano constants?? That would be confusing. sometimes Theano constants?? That would be confusing.
""" """
@staticmethod
def shape_ir(i, r): def shape_ir(self, i, r):
#TODO: Write a doc string for this method #TODO: Write a doc string for this method
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]: if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
return lscalar_one return self.lscalar_one
else: else:
return Shape_i(i).make_node(r).outputs[0] return Shape_i(i).make_node(r).outputs[0]
...@@ -595,7 +593,7 @@ class ShapeFeature(object): ...@@ -595,7 +593,7 @@ class ShapeFeature(object):
assert s_i is not None assert s_i is not None
if s_i == 1: if s_i == 1:
# don't make the optimizer merge a zillion ones together # don't make the optimizer merge a zillion ones together
return lscalar_one return self.lscalar_one
if type(s_i) is int or isinstance(s_i, numpy.integer): if type(s_i) is int or isinstance(s_i, numpy.integer):
# this shape is a constant # this shape is a constant
assert s_i >= 0 assert s_i >= 0
...@@ -693,6 +691,11 @@ class ShapeFeature(object): ...@@ -693,6 +691,11 @@ class ShapeFeature(object):
def on_attach(self, env): def on_attach(self, env):
assert not hasattr(env, 'shape_feature') assert not hasattr(env, 'shape_feature')
env.shape_feature = self env.shape_feature = self
# Must be local to the object as otherwise we reuse the same
# variable for multiple env!
self.lscalar_one = T.constant(1, dtype='int64')
assert self.lscalar_one.type == T.lscalar
self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple) self.shape_of = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.scheduled = {} # Variable -> self.scheduled = {} # Variable ->
for node in env.toposort(): for node in env.toposort():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论