提交 ce440e7e authored 作者: Joseph Turian's avatar Joseph Turian

Backported to python2.4

上级 91a95bec
...@@ -198,8 +198,10 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -198,8 +198,10 @@ def infer_reuse_pattern(env, outputs_to_disown):
seen.add(r) seen.add(r)
do_not_reuse.append(r) do_not_reuse.append(r)
op = r.owner op = r.owner
dmap = op.destroy_map() if hasattr(op, 'destroy_map') else {} if hasattr(op, 'destroy_map'): dmap = op.destroy_map()
vmap = op.view_map() if hasattr(op, 'view_map') else {} else: dmap = {}
if hasattr(op, 'view_map'): vmap = op.view_map()
else: vmap = {}
cat = lambda x, y: list(x) + list(y) cat = lambda x, y: list(x) + list(y)
for r2 in reduce(cat, dmap.values()) + reduce(cat, vmap.values()): for r2 in reduce(cat, dmap.values()) + reduce(cat, vmap.values()):
accumulate(r2) accumulate(r2)
......
...@@ -14,11 +14,22 @@ if sys.version_info[:2] < (2,5): ...@@ -14,11 +14,22 @@ if sys.version_info[:2] < (2,5):
if element: if element:
return True return True
return False return False
def partial(func, *args, **keywords):
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
newkeywords.update(fkeywords)
return func(*(args + fargs), **newkeywords)
newfunc.func = func
newfunc.args = args
newfunc.keywords = keywords
return newfunc
else: else:
# Only bother with this else clause and the __all__ line if you are putting # Only bother with this else clause and the __all__ line if you are putting
# this in a separate file. # this in a separate file.
import __builtin__ import __builtin__
all = __builtin__.all all = __builtin__.all
any = __builtin__.any any = __builtin__.any
import functools
partial = functools.partial
__all__ = ['all', 'any'] __all__ = ['all', 'any']
...@@ -4,11 +4,10 @@ import math ...@@ -4,11 +4,10 @@ import math
from copy import copy from copy import copy
from functools import partial
import gof import gof
from gof import Result, GuardedOp, Env, utils from gof import Result, GuardedOp, Env, utils
from gof.python25 import partial
def as_scalar(x, name = None): def as_scalar(x, name = None):
if isinstance(x, gof.Op): if isinstance(x, gof.Op):
......
...@@ -15,7 +15,7 @@ import blas # for gemm, dot ...@@ -15,7 +15,7 @@ import blas # for gemm, dot
import elemwise as s2t import elemwise as s2t
import scalar as scal import scalar as scal
from functools import partial from gof.python25 import partial
class Tensor(Result): class Tensor(Result):
...@@ -617,10 +617,13 @@ class Subtensor_dx(Op, Viewer): ...@@ -617,10 +617,13 @@ class Subtensor_dx(Op, Viewer):
cdata = [] cdata = []
for c in self.idx_list: for c in self.idx_list:
if isinstance(c, slice): if isinstance(c, slice):
cdata.append(slice( if c.start is None: start = None
None if c.start is None else self.inputs[c.start].data, else: start = self.inputs[c.start].data
None if c.stop is None else self.inputs[c.stop].data, if c.stop is None: stop = None
None if c.step is None else self.inputs[c.step].data)) else: stop = self.inputs[c.stop].data
if c.step is None: step = None
else: step = self.inputs[c.step].data
cdata.append(slice(start, stop, step))
else: else:
d = self.inputs[c].data d = self.inputs[c].data
assert 'int' in str(d.dtype) assert 'int' in str(d.dtype)
...@@ -680,9 +683,12 @@ class Subtensor(Op, Viewer): ...@@ -680,9 +683,12 @@ class Subtensor(Op, Viewer):
inputs.append(ai) inputs.append(ai)
except TypeError: except TypeError:
if isinstance(idx, slice): if isinstance(idx, slice):
start = None if idx.start is None else asidx(idx.start) if idx.start is None: start = None
stop = None if idx.stop is None else asidx(idx.stop) else: start = asidx(idx.start)
step = None if idx.step is None else asidx(idx.step) if idx.stop is None: stop = None
else: stop = asidx(idx.stop)
if idx.step is None: step = None
else: step = asidx(idx.step)
# If we get here, then everything got turned (successfully) # If we get here, then everything got turned (successfully)
# into a scal.Scalar (with integer dtype) or None # into a scal.Scalar (with integer dtype) or None
...@@ -734,10 +740,13 @@ class Subtensor(Op, Viewer): ...@@ -734,10 +740,13 @@ class Subtensor(Op, Viewer):
cdata = [] cdata = []
for c in self.idx_list: for c in self.idx_list:
if isinstance(c, slice): if isinstance(c, slice):
cdata.append(slice( if c.start is None: start = None
None if c.start is None else self.inputs[c.start].data, else: start = self.inputs[c.start].data
None if c.stop is None else self.inputs[c.stop].data, if c.stop is None: stop = None
None if c.step is None else self.inputs[c.step].data)) else: stop = self.inputs[c.stop].data
if c.step is None: step = None
else: step = self.inputs[c.step].data
cdata.append(slice(start, stop, step))
else: else:
d = self.inputs[c].data d = self.inputs[c].data
assert 'int' in str(d.dtype) assert 'int' in str(d.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论