提交 9b7319ba authored 作者: James Bergstra's avatar James Bergstra

disabling TensorType.shape pending support for partially-known type attributes

上级 df700c8f
...@@ -196,20 +196,20 @@ def local_shape_lift_sum(node): ...@@ -196,20 +196,20 @@ def local_shape_lift_sum(node):
register_canonicalize(local_shape_lift_sum, 'shape_lift') register_canonicalize(local_shape_lift_sum, 'shape_lift')
@gof.local_optimizer([T.shape, T.dot]) @gof.local_optimizer([T._shape, T.dot])
def local_shape_lift_dot(node): def local_shape_lift_dot(node):
""" """
shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]] shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]]
""" """
if not opt.check_chain(node, T.shape, T.dot): if not opt.check_chain(node, T._shape, T.dot):
return False return False
a, b = node.inputs[0].owner.inputs a, b = node.inputs[0].owner.inputs
if a.type.ndim == 2 and b.type.ndim == 2: if a.type.ndim == 2 and b.type.ndim == 2:
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs return T.make_lvector.make_node(T._shape(a)[0], T._shape(b)[1]).outputs
elif a.type.ndim == 1 and b.type.ndim == 2: elif a.type.ndim == 1 and b.type.ndim == 2:
return T.make_lvector.make_node(T.shape(b)[1]).outputs return T.make_lvector.make_node(T._shape(b)[1]).outputs
elif a.type.ndim == 2 and b.type.ndim == 1: elif a.type.ndim == 2 and b.type.ndim == 1:
return T.make_lvector.make_node(T.shape(a)[0]).outputs return T.make_lvector.make_node(T._shape(a)[0]).outputs
elif a.type.ndim == 1 and b.type.ndim == 1: elif a.type.ndim == 1 and b.type.ndim == 1:
return T.make_lvector.make_node().outputs return T.make_lvector.make_node().outputs
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论