提交 bcae5c00 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change T.shape for T._shape, since the latter is an instance of the Shape Op.

上级 4ac9cbd2
......@@ -159,34 +159,35 @@ register_canonicalize(local_dimshuffle_lift)
# Shape lifters #
#################
@gof.local_optimizer([T.shape, None])
@gof.local_optimizer([T._shape, None])
def local_shape_lift_elemwise(node):
"""
shape(elemwise_op(..., x, ...)) -> shape(x)
Where x contains the maximal shape information.
"""
if not opt.check_chain(node, T.shape, T.Elemwise):
if not opt.check_chain(node, T._shape, T.Elemwise):
return False
output = node.inputs[0]
parent = output.owner
for input in parent.inputs:
if input.type.broadcastable == output.type.broadcastable:
return T.shape.make_node(input).outputs
return T._shape(input),
return False
register_canonicalize(local_shape_lift_elemwise, 'shape_lift')
register_specialize(local_shape_lift_elemwise, 'shape_lift')
@gof.local_optimizer([T.shape, None])
@gof.local_optimizer([T._shape, None])
def local_shape_lift_sum(node):
"""
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
"""
if not opt.check_chain(node, T.shape, T.Sum):
if not opt.check_chain(node, T._shape, T.Sum):
return False
input = node.inputs[0].owner.inputs[0]
......@@ -195,7 +196,7 @@ def local_shape_lift_sum(node):
axis = range(input.type.ndim)
ish = T.shape(input)
ish = T._shape(input)
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论