提交 77634eed authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4843 from lamblin/useless_reshape

Remove useless reshape
...@@ -21,6 +21,7 @@ from ..basic_ops import ( ...@@ -21,6 +21,7 @@ from ..basic_ops import (
host_from_gpu, HostFromGpu, GpuFromHost, GpuReshape, GpuToGpu, host_from_gpu, HostFromGpu, GpuFromHost, GpuReshape, GpuToGpu,
GpuAlloc, GpuAllocEmpty, GpuContiguous, GpuAlloc, GpuAllocEmpty, GpuContiguous,
gpu_join, GpuJoin, GpuSplit, GpuEye, gpu_contiguous) gpu_join, GpuJoin, GpuSplit, GpuEye, gpu_contiguous)
from ..elemwise import GpuDimShuffle, GpuElemwise
from ..subtensor import GpuSubtensor from ..subtensor import GpuSubtensor
from .config import mode_with_gpu, mode_without_gpu, test_ctx_name from .config import mode_with_gpu, mode_without_gpu, test_ctx_name
...@@ -324,7 +325,7 @@ class G_reshape(test_basic.T_reshape): ...@@ -324,7 +325,7 @@ class G_reshape(test_basic.T_reshape):
mode=mode_with_gpu, mode=mode_with_gpu,
ignore_topo=(HostFromGpu, GpuFromHost, ignore_topo=(HostFromGpu, GpuFromHost,
theano.compile.DeepCopyOp, theano.compile.DeepCopyOp,
theano.gpuarray.elemwise.GpuElemwise, GpuDimShuffle, GpuElemwise,
theano.tensor.opt.Shape_i, theano.tensor.opt.Shape_i,
theano.tensor.opt.MakeVector)) theano.tensor.opt.MakeVector))
assert self.op == GpuReshape assert self.op == GpuReshape
......
...@@ -39,7 +39,7 @@ from theano import scalar ...@@ -39,7 +39,7 @@ from theano import scalar
from theano.scalar import basic from theano.scalar import basic
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
from theano.compile.ops import Shape_i from theano.compile.ops import Shape, Shape_i
from theano.tensor.type import (values_eq_approx_remove_inf, from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
values_eq_approx_remove_inf_nan) values_eq_approx_remove_inf_nan)
...@@ -48,7 +48,8 @@ from theano.gof.opt import (Optimizer, pre_constant_merge, ...@@ -48,7 +48,8 @@ from theano.gof.opt import (Optimizer, pre_constant_merge,
pre_greedy_local_optimizer) pre_greedy_local_optimizer)
from theano.gof import toolbox from theano.gof import toolbox
from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError, from theano.tensor.basic import (Alloc, get_scalar_constant_value, ShapeError,
extract_constant, NotScalarConstantError) extract_constant, NotScalarConstantError,
Reshape)
from six import StringIO from six import StringIO
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
...@@ -574,25 +575,6 @@ def local_dimshuffle_lift(node): ...@@ -574,25 +575,6 @@ def local_dimshuffle_lift(node):
""" """
op = node.op op = node.op
if (isinstance(op, T.Reshape) and
node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
new_order = node.inputs[0].owner.op.new_order
new_order = [i for i in new_order if i != 'x']
input = node.inputs[0].owner.inputs[0]
broadcastables = input.broadcastable
new_order_of_nonbroadcastables = []
for i, bd in zip(new_order, broadcastables):
if not bd:
new_order_of_nonbroadcastables.append(i)
no_change_in_order = all(
new_order_of_nonbroadcastables[i] <= new_order_of_nonbroadcastables[i + 1]
for i in xrange(len(new_order_of_nonbroadcastables) - 1))
if no_change_in_order:
shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
if not isinstance(op, DimShuffle): if not isinstance(op, DimShuffle):
return False return False
...@@ -626,6 +608,42 @@ def local_dimshuffle_lift(node): ...@@ -626,6 +608,42 @@ def local_dimshuffle_lift(node):
return [ret] return [ret]
@register_canonicalize
@gof.local_optimizer([Reshape])
def local_useless_dimshuffle_in_reshape(node):
"""
Removes useless DimShuffle operation inside Reshape:
reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
reshape(col.dimshuffle(0), shp) => reshape(col, shp)
"""
op = node.op
if not isinstance(op, Reshape):
return False
if not (node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
return False
new_order = node.inputs[0].owner.op.new_order
input = node.inputs[0].owner.inputs[0]
broadcastables = node.inputs[0].broadcastable
new_order_of_nonbroadcast = []
for i, bd in zip(new_order, broadcastables):
if not bd:
new_order_of_nonbroadcast.append(i)
no_change_in_order = all(
new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
for i in xrange(len(new_order_of_nonbroadcast) - 1))
if no_change_in_order:
shape = node.inputs[1]
ret = op.__class__(node.outputs[0].ndim)(input, shape)
copy_stack_trace(node.outputs[0], ret)
return [ret]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.DimShuffle]) @gof.local_optimizer([T.DimShuffle])
def local_lift_transpose_through_dot(node): def local_lift_transpose_through_dot(node):
...@@ -4165,15 +4183,135 @@ register_canonicalize(local_reshape_chain(T.Reshape), ...@@ -4165,15 +4183,135 @@ register_canonicalize(local_reshape_chain(T.Reshape),
@gof.local_optimizer([T.Reshape]) @gof.local_optimizer([T.Reshape])
def local_useless_reshape(node): def local_useless_reshape(node):
""" """
Remove Reshape when both the input and the output have a Remove two kinds of useless reshape.
single dimension.
Remove Reshape when both the input and output have a single dimension.
Remove Reshape when reshaping to the shape of the input.
""" """
if isinstance(node.op, T.Reshape): op = node.op
if (node.inputs[0].ndim == 1 and node.outputs[0].ndim == 1 and if not isinstance(op, Reshape):
node.inputs[0].broadcastable == return False
node.outputs[0].broadcastable):
return [node.inputs[0]] input = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
if input.ndim != output.ndim:
return False
# Simple case: both input and output have a single dimension.
# This could hide errors if the user provides inconsistent shapes.
if (input.ndim == 1 and output.ndim == 1 and
input.broadcastable == output.broadcastable):
return [input]
# Second case: all the shapes match the input shape
# Match Reshape(x, x.shape)
if output_shape.owner and isinstance(output_shape.owner.op, Shape):
shape_input = output_shape.owner.inputs[0]
if shape_input == input:
return [input]
# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions
if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
output_shape_is = output_shape.owner.inputs
if not hasattr(node, 'fgraph'):
shape_feature = None
else:
shape_feature = getattr(node.fgraph, 'shape_feature', None)
shape_match = [False] * input.ndim
for dim in xrange(input.ndim):
outshp_i = output_shape_is[dim]
# Match Shape_i{dim}(input)
if (outshp_i.owner and isinstance(outshp_i.owner.op, Shape_i) and
outshp_i.owner.op.i == dim and
outshp_i.owner.inputs[0] == input):
shape_match[dim] = True
continue
# Match Shape(input)[dim]
if (outshp_i.owner and isinstance(outshp_i.owner.op, Subtensor) and
len(outshp_i.owner.inputs) == 2 and
extract_constant(outshp_i.owner.inputs[1]) == dim):
subtensor_inp = outshp_i.owner.inputs[0]
if (subtensor_inp.owner and
isinstance(subtensor_inp.owner.op, Shape)):
shape_input_i = subtensor_inp.owner.inputs[0]
if shape_input_i == input:
shape_match[dim] = True
continue
# Match 1 if input.broadcastable[dim] is True
if (input.broadcastable[dim] and
extract_constant(outshp_i, only_process_constants=1) == 1):
shape_match[dim] = True
continue
# Match shape_of[input][dim] or its constant equivalent
if shape_feature:
inpshp_i = shape_feature.get_shape(input, dim)
if (inpshp_i == outshp_i or
(extract_constant(inpshp_i, only_process_constants=1) ==
extract_constant(outshp_i, only_process_constants=1))):
shape_match[dim] = True
continue
if all(shape_match):
return [input]
# TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case.
@register_canonicalize
@gof.local_optimizer([T.Reshape])
def local_reshape_to_dimshuffle(node):
"""
Broadcastable dimensions in Reshape are replaced with dimshuffle.
The goal is to avoid using reshape to add or remove broadcastable
dimensions, but use dimshuffle instead, so dimshuffles can cancel out
or be removed later on.
For example:
- reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
- reshape(x, (1, m, 1, n, 1, 1))
--> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
"""
op = node.op
if not isinstance(op, Reshape):
return False
input = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
dimshuffle_new_order = []
new_output_shape = []
index = 0 # index over the output of the new reshape
for i in xrange(output.ndim):
# Since output_shape is a symbolic vector, we trust extract_constant
# to go through however it is formed to see if its i-th element is 1.
# We need only_process_constants=False for that.
dim = extract_constant(output_shape[i], only_process_constants=False,
elemwise=False)
if dim == 1:
dimshuffle_new_order.append('x')
else:
dimshuffle_new_order.append(index)
new_output_shape.append(dim)
index = index + 1
if index != output.ndim:
inner = op.__class__(len(new_output_shape))(input, new_output_shape)
copy_stack_trace(output, inner)
new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
copy_stack_trace(output, new_node)
return new_node
@register_canonicalize @register_canonicalize
......
...@@ -5132,13 +5132,17 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5132,13 +5132,17 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
self.ignore_topo = ignore_topo self.ignore_topo = ignore_topo
super(T_reshape, self).__init__(name) super(T_reshape, self).__init__(name)
def function(self, inputs, outputs): def function(self, inputs, outputs, ignore_empty=False):
f = function(inputs, outputs, mode=self.mode) f = function(inputs, outputs, mode=self.mode)
if self.mode is not None or theano.config.mode != "FAST_COMPILE": if self.mode is not None or theano.config.mode != "FAST_COMPILE":
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)] self.ignore_topo)]
if ignore_empty:
assert len(topo_) <= 1, topo_
else:
assert len(topo_) == 1, topo_ assert len(topo_) == 1, topo_
if len(topo_) > 0:
assert type(topo_[0].op) is self.op assert type(topo_[0].op) is self.op
return f return f
...@@ -5212,10 +5216,21 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5212,10 +5216,21 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
# test broadcast flag for constant value of 1 # test broadcast flag for constant value of 1
c = reshape(b, (b.shape[0], b.shape[1], 1)) c = reshape(b, (b.shape[0], b.shape[1], 1))
f = self.function([b], c) # That reshape may get replaced with a dimshuffle, with is ignored,
# so we pass "ignore_empty=True"
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) == assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1], [2]], [[3], [4], [5]]])) numpy.asarray([[[0], [1], [2]], [[3], [4], [5]]]))
assert (f.maker.fgraph.toposort()[-2].outputs[0].type.broadcastable == assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True))
# test broadcast flag for constant value of 1 if it cannot be
# replaced with dimshuffle
c = reshape(b, (b.shape[1], b.shape[0], 1))
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]))
assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True)) (False, False, True))
def test_m1(self): def test_m1(self):
......
...@@ -12,7 +12,7 @@ import unittest ...@@ -12,7 +12,7 @@ import unittest
import numpy import numpy
from six.moves import xrange from six.moves import xrange
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.tools import assert_raises from nose.tools import assert_raises, assert_true
from numpy.testing import dec from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
...@@ -32,8 +32,11 @@ import theano.tensor.opt as opt ...@@ -32,8 +32,11 @@ import theano.tensor.opt as opt
from theano.tensor.opt import ( from theano.tensor.opt import (
local_add_specialize, local_add_specialize,
local_dimshuffle_lift, local_dimshuffle_lift,
local_useless_dimshuffle_in_reshape,
local_useless_alloc, local_useless_alloc,
local_greedy_distributor, local_greedy_distributor,
local_useless_reshape,
local_reshape_to_dimshuffle,
mul_canonizer, mul_canonizer,
out2in, out2in,
Shape_i, Shape_i,
...@@ -60,7 +63,6 @@ from theano.tensor import ( ...@@ -60,7 +63,6 @@ from theano.tensor import (
join, join,
Subtensor, Subtensor,
TensorType, TensorType,
Tile,
tile tile
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
...@@ -222,7 +224,8 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -222,7 +224,8 @@ class test_dimshuffle_lift(unittest.TestCase):
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) self.assertTrue(hasattr(g.outputs[0].tag, 'trace'))
def test_useless_dimshuffle_in_presence_of_reshape(self):
def test_local_useless_dimshuffle_in_reshape():
vector = TensorType(broadcastable=(False,), dtype='float64')('vector') vector = TensorType(broadcastable=(False,), dtype='float64')('vector')
mat = TensorType(broadcastable=(False, False), dtype='float64')('mat') mat = TensorType(broadcastable=(False, False), dtype='float64')('mat')
row = TensorType(broadcastable=(True, False), dtype='float64')('row') row = TensorType(broadcastable=(True, False), dtype='float64')('row')
...@@ -237,17 +240,27 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -237,17 +240,27 @@ class test_dimshuffle_lift(unittest.TestCase):
[reshape_dimshuffle_vector, reshape_dimshuffle_mat, [reshape_dimshuffle_vector, reshape_dimshuffle_mat,
reshape_dimshuffle_row, reshape_dimshuffle_col]) reshape_dimshuffle_row, reshape_dimshuffle_col])
self.assertTrue(str(g) == "[Reshape{1}(DimShuffle{x,0}(vector), Shape(vector)), " assert_true(str(g) == "[Reshape{1}(DimShuffle{x,0}(vector), Shape(vector)), "
"Reshape{2}(DimShuffle{x,0,x,1}(mat), Shape(mat)), " "Reshape{2}(DimShuffle{x,0,x,1}(mat), Shape(mat)), "
"Reshape{2}(DimShuffle{1,x}(row), Shape(row)), " "Reshape{2}(DimShuffle{1,x}(row), Shape(row)), "
"Reshape{2}(DimShuffle{0}(col), Shape(col))]") "Reshape{2}(DimShuffle{0}(col), Shape(col))]")
dimshuffle_lift.optimize(g) useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
self.assertTrue(str(g) == "[Reshape{1}(vector, Shape(vector)), " useless_dimshuffle_in_reshape.optimize(g)
assert_true(str(g) == "[Reshape{1}(vector, Shape(vector)), "
"Reshape{2}(mat, Shape(mat)), " "Reshape{2}(mat, Shape(mat)), "
"Reshape{2}(row, Shape(row)), " "Reshape{2}(row, Shape(row)), "
"Reshape{2}(col, Shape(col))]") "Reshape{2}(col, Shape(col))]")
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(hasattr(g.outputs[0].tag, 'trace')) assert_true(check_stack_trace(g, ops_to_check='all'))
# Check that the optimization does not get applied when the order
# of dimensions has changed.
reshape_dimshuffle_mat2 = tensor.reshape(mat.dimshuffle('x', 1, 'x', 0), mat.shape)
h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
str_h = str(h)
useless_dimshuffle_in_reshape.optimize(h)
assert_true(str(h) == str(h))
def test_add_canonizer_problem0(): def test_add_canonizer_problem0():
...@@ -4216,23 +4229,31 @@ def test_local_mul_specialize(): ...@@ -4216,23 +4229,31 @@ def test_local_mul_specialize():
class T_Tile(unittest.TestCase): class T_Tile(unittest.TestCase):
def test_local_useless_tile(self): def test_local_useless_tile(self):
# Tile op is deprecated so the tile function doesn't use it
# anymore, we'll test here the op directly
v = T.vector() v = T.vector()
m = T.matrix() m = T.matrix()
mode = None mode = None
if theano.config.mode == "FAST_COMPILE": if theano.config.mode == "FAST_COMPILE":
mode = "FAST_RUN" mode = "FAST_RUN"
for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]:
# Currently, only a repeat patter == ndim is supported. # When len(repeat pattern) <= var.ndim, everything is removed
for ndim in [var.ndim]: # range(1, var.ndim): # for ndim in range(1, var.ndim):
f = theano.function([var], Tile(ndim)(var, (1,)*ndim), mode=mode) for ndim in range(var.ndim + 1):
f = theano.function([var], tile(var, (1,) * ndim), mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
assert isinstance(topo[0].op, compile.DeepCopyOp) assert isinstance(topo[0].op, compile.DeepCopyOp)
f(data) f(data)
# In this case the opt only removes nodes, # In this case the opt only removes nodes,
# no need to check_stack_trace # no need to check_stack_trace
# When len(repeat pattern) > var.ndim, only a dimshuffle should be
# left, but there can be a DeepCopy as well
for ndim in range(var.ndim + 1, var.ndim + 3):
f = theano.function([var], tile(var, (1,) * ndim), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) <= 2
assert isinstance(topo[0].op, DimShuffle)
assert check_stack_trace(f, ops_to_check=[DimShuffle])
f(data)
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
...@@ -6163,7 +6184,11 @@ class Test_Reshape(unittest.TestCase): ...@@ -6163,7 +6184,11 @@ class Test_Reshape(unittest.TestCase):
assert sum(isinstance(node.op, self.op) for node in topo) == 1 assert sum(isinstance(node.op, self.op) for node in topo) == 1
def test_local_useless_reshape(): class Test_local_useless_reshape(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_0(self):
mode = theano.compile.get_default_mode().including( mode = theano.compile.get_default_mode().including(
'local_useless_reshape') 'local_useless_reshape')
i = T.iscalar('i') i = T.iscalar('i')
...@@ -6172,6 +6197,68 @@ def test_local_useless_reshape(): ...@@ -6172,6 +6197,68 @@ def test_local_useless_reshape():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo) assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_1(self):
x = theano.tensor.matrix('x')
r = x.reshape(x.shape)
m0 = theano.compile.get_default_mode()
m1 = m0.including('local_useless_reshape')
f1 = theano.function([x], r, mode=m1)
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
def test_2(self):
x = theano.tensor.matrix('x')
r = x.reshape([Shape_i(i)(x) for i in xrange(x.ndim)])
m0 = theano.compile.get_default_mode()
m1 = m0.including('local_useless_reshape')
f1 = theano.function([x], r, mode=m1)
topo = f1.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
m2 = m1.excluding('ShapeOpt')
f2 = theano.function([x], r, mode=m2)
topo = f2.maker.fgraph.toposort()
assert not any(isinstance(n.op, tensor.basic.Reshape) for n in topo)
class Test_local_reshape_to_dimshuffle(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_1(self):
reshape_lift = out2in(local_reshape_to_dimshuffle)
useless_reshape = out2in(local_useless_reshape)
x = shared(self.rng.randn(4,))
y = shared(self.rng.randn(5, 6))
reshape_x = tensor.reshape(x, (1, 4))
reshape_y = tensor.reshape(y, (1, 5, 1, 6, 1, 1))
g = FunctionGraph([x, y], [reshape_x, reshape_y])
self.assertTrue(str(g) == ("[Reshape{2}"
"(<TensorType(float64, vector)>, "
"TensorConstant{[1 4]}), "
"Reshape{6}"
"(<TensorType(float64, matrix)>, "
"TensorConstant{[1 5 1 6 1 1]})]"))
reshape_lift.optimize(g)
useless_reshape.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{x,0}"
"(<TensorType(float64, vector)>), "
"DimShuffle{x,0,x,1,x,x}"
"(Reshape{2}(<TensorType(float64, matrix)>, "
"TensorConstant{[5 6]}))]")
# Check stacktrace was copied over correctly after opt was applied
check_stack_trace(g, ops_to_check=(T.DimShuffle, T.Reshape))
def test_local_reshape_lift(): def test_local_reshape_lift():
x = tensor.tensor4() x = tensor.tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论