提交 bcae5c00 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change T.shape for T._shape, since the latter is an instance of the Shape Op.

上级 4ac9cbd2
...@@ -159,34 +159,35 @@ register_canonicalize(local_dimshuffle_lift) ...@@ -159,34 +159,35 @@ register_canonicalize(local_dimshuffle_lift)
# Shape lifters # # Shape lifters #
################# #################
@gof.local_optimizer([T.shape, None]) @gof.local_optimizer([T._shape, None])
def local_shape_lift_elemwise(node): def local_shape_lift_elemwise(node):
""" """
shape(elemwise_op(..., x, ...)) -> shape(x) shape(elemwise_op(..., x, ...)) -> shape(x)
Where x contains the maximal shape information. Where x contains the maximal shape information.
""" """
if not opt.check_chain(node, T.shape, T.Elemwise): if not opt.check_chain(node, T._shape, T.Elemwise):
return False return False
output = node.inputs[0] output = node.inputs[0]
parent = output.owner parent = output.owner
for input in parent.inputs: for input in parent.inputs:
if input.type.broadcastable == output.type.broadcastable: if input.type.broadcastable == output.type.broadcastable:
return T.shape.make_node(input).outputs return T._shape(input),
return False return False
register_canonicalize(local_shape_lift_elemwise, 'shape_lift') register_canonicalize(local_shape_lift_elemwise, 'shape_lift')
register_specialize(local_shape_lift_elemwise, 'shape_lift')
@gof.local_optimizer([T.shape, None]) @gof.local_optimizer([T._shape, None])
def local_shape_lift_sum(node): def local_shape_lift_sum(node):
""" """
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...] shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
""" """
if not opt.check_chain(node, T.shape, T.Sum): if not opt.check_chain(node, T._shape, T.Sum):
return False return False
input = node.inputs[0].owner.inputs[0] input = node.inputs[0].owner.inputs[0]
...@@ -195,7 +196,7 @@ def local_shape_lift_sum(node): ...@@ -195,7 +196,7 @@ def local_shape_lift_sum(node):
axis = range(input.type.ndim) axis = range(input.type.ndim)
ish = T.shape(input) ish = T._shape(input)
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs # return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论