提交 9184ad40 authored 作者: James Bergstra's avatar James Bergstra

made tsor_apply more robust and pass all test cases

上级 5bcefac4
...@@ -13,6 +13,8 @@ def ishape(v): ...@@ -13,6 +13,8 @@ def ishape(v):
class Apply(gof.Apply): class Apply(gof.Apply):
def __init__(self, op, inputs, outputs): def __init__(self, op, inputs, outputs):
super(Apply, self).__init__(op, inputs, outputs) super(Apply, self).__init__(op, inputs, outputs)
if not inputs:
return
# if any input has any shape info, then propagate it # if any input has any shape info, then propagate it
try: try:
provided, ishapes = zip(*[ishape(i) for i in inputs]) provided, ishapes = zip(*[ishape(i) for i in inputs])
...@@ -28,7 +30,11 @@ class Apply(gof.Apply): ...@@ -28,7 +30,11 @@ class Apply(gof.Apply):
# op has no infer_shape, that's fine # op has no infer_shape, that's fine
return return
oshapes = infer_shape(self, ishapes) try:
oshapes = infer_shape(self, ishapes)
except NotImplementedError:
return
for o, oshp in zip(outputs, oshapes): for o, oshp in zip(outputs, oshapes):
o.tag.shape = oshp o.tag.shape = oshp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论