提交 044bbaf0 authored 作者: Joseph Turian's avatar Joseph Turian

local_shape_lift_dot now works with vectors, not just matrices

上级 9be3ec1a
...@@ -186,7 +186,16 @@ def local_shape_lift_dot(node): ...@@ -186,7 +186,16 @@ def local_shape_lift_dot(node):
if not opt.check_chain(node, T.shape, T.dot): if not opt.check_chain(node, T.shape, T.dot):
return False return False
a, b = node.inputs[0].owner.inputs a, b = node.inputs[0].owner.inputs
if a.type.ndim == 2 and b.type.ndim == 2:
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
elif a.type.ndim == 1 and b.type.ndim == 2:
return T.make_lvector.make_node(T.shape(b)[1]).outputs
elif a.type.ndim == 2 and b.type.ndim == 1:
return T.make_lvector.make_node(T.shape(a)[0]).outputs
elif a.type.ndim == 1 and b.type.ndim == 1:
return T.make_lvector.make_node().outputs
else:
return False
register_canonicalize(local_shape_lift_dot, 'shape_lift') register_canonicalize(local_shape_lift_dot, 'shape_lift')
......
...@@ -189,6 +189,22 @@ def test_mixeddiv(): ...@@ -189,6 +189,22 @@ def test_mixeddiv():
d = dscalar() d = dscalar()
assert 0 == function([i,d], d*(i/(i+1)))(3, 1.0) assert 0 == function([i,d], d*(i/(i+1)))(3, 1.0)
def test_local_shape_lift_dot():
args_to_result = {
(fvector, fvector): "[]",
(fvector, fmatrix): "[<TensorType(float32, matrix)>.shape[1]]",
(fmatrix, fvector): "[<TensorType(float32, matrix)>.shape[0]]",
(fmatrix, fmatrix): "[<TensorType(float32, matrix)>.shape[0], <TensorType(float32, matrix)>.shape[1]]",
}
for x in [fvector, fmatrix]:
for y in [fvector, fmatrix]:
i = x()
j = y()
d = shape(dot(i,j))
g = Env([i,j], [d])
gof.TopoOptimizer(gof.LocalOptGroup(local_shape_lift_dot), order='out_to_in').optimize(g)
assert pprint(g.outputs[0]) == args_to_result[(x,y)]
# def test_plusmin(self): # def test_plusmin(self):
# x, y, z = inputs() # x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论