提交 bcd2bfea authored 作者: James Bergstra's avatar James Bergstra

adding tsor_apply file

上级 24294c33
"""Apply for use with Tensors that implements shape propagation via variable.tag.shape
"""
import sys
from theano import gof
def ishape(v):
try:
return (True, v.tag.shape)
except AttributeError:
return (False, (None,)*v.type.ndim)
class Apply(gof.Apply):
def __init__(self, op, inputs, outputs):
super(Apply, self).__init__(op, inputs, outputs)
# if any input has any shape info, then propagate it
try:
provided, ishapes = zip(*[ishape(i) for i in inputs])
except AttributeError:
# i.type.ndim didn't make sense for some i
return
if provided == [False for i in inputs]:
# no input had a tag.shape
return
try:
infer_shape = op.infer_shape
except AttributeError:
# op has no infer_shape, that's fine
return
oshapes = infer_shape(self, ishapes)
for o, oshp in zip(outputs, oshapes):
o.tag.shape = oshp
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论