提交 fc66c3fd authored 作者: lamblin's avatar lamblin

Merge pull request #1470 from nouiz/split_file

Split file
...@@ -727,12 +727,19 @@ copy_reg.pickle(Function, _pickle_Function) ...@@ -727,12 +727,19 @@ copy_reg.pickle(Function, _pickle_Function)
### ###
class SanityCheckFunction(Function): class SanityCheckFunction(Function):
"""Deprecated. It is not used and not tested anywhere in Theano!
Also, we should remove the check_equal and related function in
this file, and use Type.values_equals() instead.
"""
def __init__(self, others, check_equal, *args, **kwargs): def __init__(self, others, check_equal, *args, **kwargs):
super(SanityCheckFunction, self).__init__(*args, **kwargs) super(SanityCheckFunction, self).__init__(*args, **kwargs)
self.others = others self.others = others
self.check_equal = check_equal self.check_equal = check_equal
# DEPRECATED? Is this just for DualLinker? # DEPRECATED? Is this just for DualLinker?
warnings.warn("SanityCheckFunction is deprecated")
def __setitem__(self, item, value): def __setitem__(self, item, value):
super(SanityCheckFunction, self).__setitem__(item, value) super(SanityCheckFunction, self).__setitem__(item, value)
......
...@@ -78,4 +78,4 @@ from theano.gof.type import \ ...@@ -78,4 +78,4 @@ from theano.gof.type import \
Type, Generic, generic Type, Generic, generic
from theano.gof.utils import \ from theano.gof.utils import \
object2, MethodNotDefined hashtype, object2, MethodNotDefined
...@@ -22,6 +22,11 @@ def hashgen(): ...@@ -22,6 +22,11 @@ def hashgen():
hashgen.next = 0 hashgen.next = 0
def hashtype(self):
    """Return a hash value derived from the type of ``self``.

    The result is the XOR of the hash of the type's ``__name__`` and the
    hash of its ``__module__``.  It therefore depends only on the class of
    the object, never on the object's contents.

    :param self: any object; only ``type(self)`` is inspected.
    :return: an integer combining the hashes of the type's name and module.
    """
    t = type(self)
    return hash(t.__name__) ^ hash(t.__module__)
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
......
...@@ -437,8 +437,8 @@ acceptable_ops = (theano.tensor.basic.Dot, ...@@ -437,8 +437,8 @@ acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Shape, theano.tensor.basic.Shape,
theano.tensor.basic.SpecifyShape, theano.tensor.basic.SpecifyShape,
theano.tensor.basic.MaxAndArgmax, theano.tensor.basic.MaxAndArgmax,
theano.tensor.basic.Subtensor, theano.tensor.Subtensor,
theano.tensor.basic.IncSubtensor, theano.tensor.IncSubtensor,
theano.tensor.basic.Rebroadcast, theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc, theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
......
...@@ -798,7 +798,7 @@ __global__ void k_take_3(const int d0, const int d1, const int d2, ...@@ -798,7 +798,7 @@ __global__ void k_take_3(const int d0, const int d1, const int d2,
// This prevents us from setting it to 0 before each use // This prevents us from setting it to 0 before each use
static int* err_var = NULL; static int* err_var = NULL;
// We try to be similat to the PyArray_TakeFrom function // We try to be similar to the PyArray_TakeFrom function
//http://docs.scipy.org/doc/numpy/reference/c-api.array.html //http://docs.scipy.org/doc/numpy/reference/c-api.array.html
//TODO: support clip modes other than raise (clip, wrap) //TODO: support clip modes other than raise (clip, wrap)
//self is the input that we copy data from. //self is the input that we copy data from.
......
...@@ -912,8 +912,9 @@ class T_Join_and_Split(theano.tensor.tests.test_basic.T_Join_and_Split): ...@@ -912,8 +912,9 @@ class T_Join_and_Split(theano.tensor.tests.test_basic.T_Join_and_Split):
self.shared = cuda.shared_constructor self.shared = cuda.shared_constructor
import theano.tensor.tests.test_subtensor
# This is to avoid duplicating tests. # This is to avoid duplicating tests.
class T_subtensor(theano.tensor.tests.test_basic.T_subtensor): class T_subtensor(theano.tensor.tests.test_subtensor.T_subtensor):
# This prevents nose from printing method docstrings instead of method # This prevents nose from printing method docstrings instead of method
# names # names
...@@ -933,7 +934,7 @@ class T_subtensor(theano.tensor.tests.test_basic.T_subtensor): ...@@ -933,7 +934,7 @@ class T_subtensor(theano.tensor.tests.test_basic.T_subtensor):
cuda.GpuAdvancedSubtensor1, cuda.GpuAdvancedIncSubtensor1) cuda.GpuAdvancedSubtensor1, cuda.GpuAdvancedIncSubtensor1)
def __init__(self, name): def __init__(self, name):
return super(theano.tensor.tests.test_basic.T_subtensor, return super(theano.tensor.tests.test_subtensor.T_subtensor,
self).__init__(name) self).__init__(name)
def test_adv_sub1_fast(self): def test_adv_sub1_fast(self):
......
...@@ -831,6 +831,8 @@ det = Det() ...@@ -831,6 +831,8 @@ det = Det()
def trace(X): def trace(X):
""" """
Returns the sum of diagonal elements of matrix X. Returns the sum of diagonal elements of matrix X.
:note: work on GPU since 0.6rc4.
""" """
return extract_diag(X).sum() return extract_diag(X).sum()
......
...@@ -691,7 +691,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -691,7 +691,7 @@ class ScanSaveMem(gof.Optimizer):
break break
# 2.2 non-subtensor nodes # 2.2 non-subtensor nodes
#=> output needs all its intermediate values #=> output needs all its intermediate values
elif not isinstance(cl.op, tensor.basic.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
global_nsteps = None global_nsteps = None
slices[i] = None slices[i] = None
break break
...@@ -699,7 +699,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -699,7 +699,7 @@ class ScanSaveMem(gof.Optimizer):
#=> output might need to store just a subset of its values #=> output might need to store just a subset of its values
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = tensor.basic.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
# if unable to extract idx_list # if unable to extract idx_list
...@@ -719,7 +719,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -719,7 +719,7 @@ class ScanSaveMem(gof.Optimizer):
length = shape_of[out][0] length = shape_of[out][0]
except KeyError: except KeyError:
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.get_canonical_form_slice(
this_slice[0], length) this_slice[0], length)
slices[i] += [(cf_slice, this_slice)] slices[i] += [(cf_slice, this_slice)]
...@@ -795,12 +795,12 @@ class ScanSaveMem(gof.Optimizer): ...@@ -795,12 +795,12 @@ class ScanSaveMem(gof.Optimizer):
if type(cl) == str: if type(cl) == str:
store_steps[i] = 0 store_steps[i] = 0
break break
elif not isinstance(cl.op, tensor.basic.Subtensor): elif not isinstance(cl.op, tensor.Subtensor):
store_steps[i] = 0 store_steps[i] = 0
break break
else: else:
this_slice = tensor.basic.get_idx_list(cl.inputs, this_slice = tensor.get_idx_list(cl.inputs,
cl.op.idx_list) cl.op.idx_list)
if this_slice is None: if this_slice is None:
store_steps[i] = 0 store_steps[i] = 0
break break
...@@ -817,8 +817,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -817,8 +817,8 @@ class ScanSaveMem(gof.Optimizer):
length = shape_of[out][0] length = shape_of[out][0]
except KeyError: except KeyError:
length = out.shape[0] length = out.shape[0]
cf_slice = tensor.basic.get_canonical_form_slice( cf_slice = tensor.get_canonical_form_slice(
this_slice[0], length) this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
start = tensor.basic.extract_constant( start = tensor.basic.extract_constant(
...@@ -973,9 +973,9 @@ class ScanSaveMem(gof.Optimizer): ...@@ -973,9 +973,9 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:]) nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
# slice inputs # slice inputs
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
...@@ -1014,8 +1014,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1014,8 +1014,8 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (sanitize(position),) + \ nw_slice = (sanitize(position),) + \
tuple(old_slices[1:]) tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
......
...@@ -4,6 +4,8 @@ __docformat__ = "restructuredtext en" ...@@ -4,6 +4,8 @@ __docformat__ = "restructuredtext en"
import warnings import warnings
from theano.tensor.basic import * from theano.tensor.basic import *
from theano.tensor.subtensor import *
from theano.tensor.type_other import *
from theano.tensor import opt from theano.tensor import opt
from theano.tensor import opt_uncanonicalize from theano.tensor import opt_uncanonicalize
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -8,6 +8,7 @@ import numpy ...@@ -8,6 +8,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.tensor import subtensor
from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector from theano.tensor import elemwise, dmatrix, fmatrix, dvector, fvector
from theano.tensor import opt from theano.tensor import opt
from theano.compile import optdb from theano.compile import optdb
...@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -1004,7 +1005,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
# typically we should not need the gradient w.r.t. dy). # typically we should not need the gradient w.r.t. dy).
y_idx_range = tensor.arange(y_idx.shape[0]) y_idx_range = tensor.arange(y_idx.shape[0])
g_dy = tensor.sum( g_dy = tensor.sum(
g_dx * tensor.AdvancedIncSubtensor()( g_dx * subtensor.AdvancedIncSubtensor()(
sm, tensor.fill(dy, -1), y_idx_range, y_idx), sm, tensor.fill(dy, -1), y_idx_range, y_idx),
axis=1) axis=1)
g_sm = dy.dimshuffle(0, 'x') * g_dx g_sm = dy.dimshuffle(0, 'x') * g_dx
...@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1396,7 +1397,7 @@ def _check_rows_is_arange_len_labels(rows, labels):
# Not sure if that case happens any more after the introduction of # Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, tensor.Subtensor): if isinstance(stop.owner.op, subtensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if list(shape_subtensor.op.idx_list) == [0]: if list(shape_subtensor.op.idx_list) == [0]:
shape_var, = shape_subtensor.inputs shape_var, = shape_subtensor.inputs
...@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1424,7 +1425,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
log = None log = None
sm = None sm = None
# First case: log(softmax(x))[rows, labels] # First case: log(softmax(x))[rows, labels]
if isinstance(node.op, tensor.AdvancedSubtensor): if isinstance(node.op, subtensor.AdvancedSubtensor):
try: try:
log, rows, labels = node.inputs log, rows, labels = node.inputs
except Exception: except Exception:
...@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1435,7 +1436,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
# Second case: log(softmax(x)[rows, labels]) # Second case: log(softmax(x)[rows, labels])
if node.op == tensor.log: if node.op == tensor.log:
pre_log = node.inputs[0].owner pre_log = node.inputs[0].owner
if pre_log and isinstance(pre_log.op, tensor.AdvancedSubtensor): if pre_log and isinstance(pre_log.op, subtensor.AdvancedSubtensor):
try: try:
sm, rows, labels = pre_log.inputs sm, rows, labels = pre_log.inputs
except Exception: except Exception:
...@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1524,7 +1525,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# After the check for AdvancedIncSubtensor, if anything does not fit with # After the check for AdvancedIncSubtensor, if anything does not fit with
# the formula above, there's no way to fit it with the second case, # the formula above, there's no way to fit it with the second case,
# so we return immediately. # so we return immediately.
if d_sm.owner and isinstance(d_sm.owner.op, tensor.AdvancedIncSubtensor): if d_sm.owner and isinstance(d_sm.owner.op, subtensor.AdvancedIncSubtensor):
try: try:
z, incr, rows, labels = d_sm.owner.inputs z, incr, rows, labels = d_sm.owner.inputs
except Exception: except Exception:
...@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1566,7 +1567,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if not denom.owner: if not denom.owner:
return return
if isinstance(denom.owner.op, tensor.AdvancedSubtensor): if isinstance(denom.owner.op, subtensor.AdvancedSubtensor):
# Base case # Base case
adv_subtensor = denom adv_subtensor = denom
#out_grad /= 1. #out_grad /= 1.
...@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1575,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# and the output gradient # and the output gradient
for i, input in enumerate(denom.owner.inputs): for i, input in enumerate(denom.owner.inputs):
if input.owner and isinstance(input.owner.op, if input.owner and isinstance(input.owner.op,
tensor.AdvancedSubtensor): subtensor.AdvancedSubtensor):
other_inputs = [in_ for (j, other_inputs = [in_ for (j,
in_) in enumerate(denom.owner.inputs) if j != i] in_) in enumerate(denom.owner.inputs) if j != i]
if len(other_inputs) == 1: if len(other_inputs) == 1:
...@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1630,7 +1631,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
return return
# Check the numerator (AdvancedIncSubtensor) # Check the numerator (AdvancedIncSubtensor)
if num.owner and isinstance(num.owner.op, tensor.AdvancedIncSubtensor): if num.owner and isinstance(num.owner.op, subtensor.AdvancedIncSubtensor):
try: try:
z, incr, rows, labels = num.owner.inputs z, incr, rows, labels = num.owner.inputs
except Exception: except Exception:
......
...@@ -24,6 +24,8 @@ from theano.gof.python25 import maxsize ...@@ -24,6 +24,8 @@ from theano.gof.python25 import maxsize
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.configparser import config from theano.configparser import config
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, AdvancedIncSubtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1217,13 +1219,13 @@ def local_track_shape_i(node): ...@@ -1217,13 +1219,13 @@ def local_track_shape_i(node):
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([Subtensor])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
# replace all subtensor(make_vector) like: # replace all subtensor(make_vector) like:
# [a,b,c][0] -> a # [a,b,c][0] -> a
# [a,b,c][0:2] -> [a,b] # [a,b,c][0:2] -> [a,b]
# we can do this for constant indexes # we can do this for constant indexes
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
x = node.inputs[0] x = node.inputs[0]
if x.owner and x.owner.op == make_vector: if x.owner and x.owner.op == make_vector:
...@@ -1591,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -1591,12 +1593,12 @@ def local_upcast_elemwise_constant_inputs(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Subtensor]) @gof.local_optimizer([Subtensor])
def local_useless_subtensor(node): def local_useless_subtensor(node):
""" """
Remove Subtensor if it takes the full input Remove Subtensor if it takes the full input
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
...@@ -1677,7 +1679,7 @@ def local_subtensor_lift(node): ...@@ -1677,7 +1679,7 @@ def local_subtensor_lift(node):
when x,... are broadcasted scalar or not broadcasted at all when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx]) rebroadcast(x)[idx] => rebroadcast(x[idx])
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
u = node.inputs[0] u = node.inputs[0]
if not u.owner or len(u.clients) > 1: if not u.owner or len(u.clients) > 1:
return False return False
...@@ -1736,7 +1738,7 @@ def local_subtensor_lift(node): ...@@ -1736,7 +1738,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])] new_axis += [(j, u.broadcastable[i])]
j += 1 j += 1
subt_x = T.Subtensor(node.op.idx_list)(u.owner.inputs[0]) subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x] return [rbcast_subt_x]
...@@ -1764,8 +1766,8 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -1764,8 +1766,8 @@ def merge_two_slices(slice1, len1, slice2, len2):
if type(slice1) is not slice: if type(slice1) is not slice:
raise ValueError(('First provided slice should actually be of type' raise ValueError(('First provided slice should actually be of type'
'slice and not an index !'), slice1) 'slice and not an index !'), slice1)
sl1, reverse1 = T.get_canonical_form_slice(slice1, len1) sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = T.get_canonical_form_slice(slice2, len2) sl2, reverse2 = get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice: if type(sl2) is not slice:
if reverse1 is None: if reverse1 is None:
...@@ -1885,15 +1887,15 @@ def local_subtensor_merge(node): ...@@ -1885,15 +1887,15 @@ def local_subtensor_merge(node):
expresses all slices in a canonical form, and then merges them together. expresses all slices in a canonical form, and then merges them together.
""" """
if isinstance(node.op, T.Subtensor): if isinstance(node.op, Subtensor):
u = node.inputs[0] u = node.inputs[0]
if u.owner and isinstance(u.owner.op, T.Subtensor): if u.owner and isinstance(u.owner.op, Subtensor):
# We can merge :) # We can merge :)
# x actual tensor on which we are picking slices # x actual tensor on which we are picking slices
x = u.owner.inputs[0] x = u.owner.inputs[0]
# slices of the first applied subtensor # slices of the first applied subtensor
slices1 = T.get_idx_list(u.owner.inputs, u.owner.op.idx_list) slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = T.get_idx_list(node.inputs, node.op.idx_list) slices2 = get_idx_list(node.inputs, node.op.idx_list)
# Get the shapes of the vectors ! # Get the shapes of the vectors !
try: try:
# try not to introduce new shape into the graph # try not to introduce new shape into the graph
...@@ -1927,8 +1929,8 @@ def local_subtensor_merge(node): ...@@ -1927,8 +1929,8 @@ def local_subtensor_merge(node):
else: else:
merged_slices += slices1[pos_1:] merged_slices += slices1[pos_1:]
subtens = T.Subtensor(merged_slices) subtens = Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse( sl_ins = Subtensor.collapse(
merged_slices, merged_slices,
lambda x: isinstance(x, T.Variable)) lambda x: isinstance(x, T.Variable))
out = subtens.make_node(x, *sl_ins).outputs[0] out = subtens.make_node(x, *sl_ins).outputs[0]
...@@ -1941,14 +1943,14 @@ def local_subtensor_merge(node): ...@@ -1941,14 +1943,14 @@ def local_subtensor_merge(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_subtensor_of_alloc(node): def local_subtensor_of_alloc(node):
"""alloc[x:y] -> alloc""" """alloc[x:y] -> alloc"""
if not isinstance(node.op, T.Subtensor): if not isinstance(node.op, Subtensor):
return False return False
u = node.inputs[0] u = node.inputs[0]
if u.owner is None: if u.owner is None:
return False return False
if not isinstance(u.owner.op, T.Alloc): if not isinstance(u.owner.op, T.Alloc):
return False return False
slices = T.get_idx_list(node.inputs, node.op.idx_list) slices = get_idx_list(node.inputs, node.op.idx_list)
val = u.owner.inputs[0] val = u.owner.inputs[0]
dims = u.owner.inputs[1:] dims = u.owner.inputs[1:]
assert len(slices) <= len(dims) assert len(slices) <= len(dims)
...@@ -1972,7 +1974,7 @@ def local_subtensor_of_alloc(node): ...@@ -1972,7 +1974,7 @@ def local_subtensor_of_alloc(node):
else: else:
val_slices.append(sl) val_slices.append(sl)
csl, _ = T.get_canonical_form_slice(sl, dim) csl, _ = get_canonical_form_slice(sl, dim)
if type(csl) is not slice: if type(csl) is not slice:
# That dimension is removed. # That dimension is removed.
pass pass
...@@ -2026,7 +2028,7 @@ def local_IncSubtensor_serialize(node): ...@@ -2026,7 +2028,7 @@ def local_IncSubtensor_serialize(node):
def movable(i): def movable(i):
# Return True iff this is a incsubtensor that we can move # Return True iff this is a incsubtensor that we can move
return i.owner \ return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \ and isinstance(i.owner.op, IncSubtensor) \
and i.type == o_type \ and i.type == o_type \
and len(i.clients) == 1 \ and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc and not i.owner.op.set_instead_of_inc
...@@ -2060,7 +2062,7 @@ def local_inplace_setsubtensor(node): ...@@ -2060,7 +2062,7 @@ def local_inplace_setsubtensor(node):
""" """
Also work for GpuIncSubtensor Also work for GpuIncSubtensor
""" """
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, IncSubtensor) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
node.op.idx_list, inplace=True, node.op.idx_list, inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc, set_instead_of_inc=node.op.set_instead_of_inc,
...@@ -2077,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor', ...@@ -2077,7 +2079,7 @@ compile.optdb.register('inplace_setsubtensor',
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_incsubtensor1(node): def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """ """ also work for GpuAdvancedIncSubtensor1 """
if isinstance(node.op, T.AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
inplace=True, set_instead_of_inc=node.op.set_instead_of_inc) inplace=True, set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
...@@ -2097,7 +2099,7 @@ def local_incsubtensor_of_allocs(node): ...@@ -2097,7 +2099,7 @@ def local_incsubtensor_of_allocs(node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
""" """
if isinstance(node.op, T.IncSubtensor) and not node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and not node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
...@@ -2122,7 +2124,7 @@ def local_setsubtensor_of_allocs(node): ...@@ -2122,7 +2124,7 @@ def local_setsubtensor_of_allocs(node):
when x is constant or alloc. when x is constant or alloc.
""" """
if isinstance(node.op, T.IncSubtensor) and node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace_x = None replace_x = None
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论