提交 324184a7 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bug involving shape_i op and constant folding. In short, new constants…

Fixed a bug involving the shape_i op and constant folding. In short, new constants introduced by the constant folding optimization were not inserted into the shape_of dictionary.
上级 ae6a445d
...@@ -24,6 +24,8 @@ from theano import compile #to register the optimizer built by this file ...@@ -24,6 +24,8 @@ from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler from theano.gof import toolbox, DestroyHandler
# Utilities # Utilities
def out2in(*local_opts): def out2in(*local_opts):
...@@ -395,6 +397,13 @@ class ShapeFeature(object): ...@@ -395,6 +397,13 @@ class ShapeFeature(object):
else: else:
self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s]) self.shape_of[r] = tuple([self.unpack(s_i) for s_i in s])
def init_r(self,r):
if r not in self.shape_of:
try:
self.set_shape(r, self.shape_tuple(r))
except AttributeError:
self.set_shape(r,None)
def make_vector_shape(self, r): def make_vector_shape(self, r):
return make_vector(*self.shape_of[r]) return make_vector(*self.shape_of[r])
# #
...@@ -421,11 +430,7 @@ class ShapeFeature(object): ...@@ -421,11 +430,7 @@ class ShapeFeature(object):
for i, r in enumerate(node.inputs): for i, r in enumerate(node.inputs):
# make sure we have shapes for the inputs # make sure we have shapes for the inputs
if r not in self.shape_of: self.init_r(r)
try:
self.set_shape(r, self.shape_tuple(r))
except AttributeError:
self.set_shape(r, None ) # not a TensorType variable
try: try:
shape_infer = node.op.infer_shape shape_infer = node.op.infer_shape
...@@ -453,7 +458,7 @@ class ShapeFeature(object): ...@@ -453,7 +458,7 @@ class ShapeFeature(object):
# TODO: # TODO:
# This tells us that r and new_r must have the same shape # This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do. # if we didn't know that the shapes are related, now we do.
self.init_r(new_r)
# change_input happens in two cases: # change_input happens in two cases:
# 1) we are trying to get rid of r, or # 1) we are trying to get rid of r, or
# 2) we are putting things back after a failed transaction. # 2) we are putting things back after a failed transaction.
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论