提交 edad9e5a authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Allow user to specify shape of inputs by using .tag.shape. This avoids the

madness of elemwise operation introduced by infershape of subtensor and by scan and makes plots of computational graphs readable and possibly faster.
上级 819412de
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import logging import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
import copy
import operator import operator
import itertools import itertools
import sys import sys
...@@ -572,6 +573,11 @@ class ShapeFeature(object): ...@@ -572,6 +573,11 @@ class ShapeFeature(object):
if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]: if hasattr(r.type,"broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one return self.lscalar_one
# If user provided size
elif ( hasattr(r.tag,'shape') and
r.tag.shape is not None and
r.tag.shape[i] is not None):
return T.constant(copy.copy(r.tag.shape[i]),dtype='int64')
else: else:
return Shape_i(i).make_node(r).outputs[0] return Shape_i(i).make_node(r).outputs[0]
...@@ -2732,7 +2738,14 @@ register_specialize(local_mul_specialize) ...@@ -2732,7 +2738,14 @@ register_specialize(local_mul_specialize)
@gof.local_optimizer([T.add]) @gof.local_optimizer([T.add])
def local_add_specialize(node): def local_add_specialize(node):
def fill_chain(v): def fill_chain(v):
return _fill_chain(v, node.inputs) # Not sure why this happens .. but I did not had the time to look
# into it, it probably has something to do with the dtype I'm
# providing the tag.shape of my variable
out = _fill_chain(v, node.inputs)
if out[0].dtype != node.outputs[0].dtype:
return [T.cast(out[0], dtype = node.outputs[0].dtype)]
else:
return out
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.add: if node.op == T.add:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论