提交 663ddaa0 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

some stuff

上级 ef638c6a
...@@ -189,6 +189,24 @@ def eval_outputs(outputs, ...@@ -189,6 +189,24 @@ def eval_outputs(outputs,
return rval return rval
def infer_reuse_pattern(env, outputs_to_disown):
    """Return the results whose storage must not be reused.

    Starting from every result in ``outputs_to_disown``, follow the
    ``destroy_map()`` and ``view_map()`` of the result's owner (when the
    owner defines them) and collect every result reached that is not an
    edge of ``env``.

    :param env: object exposing ``edge(r)``; results for which it returns
        a true value are not traversed.
    :param outputs_to_disown: iterable of hashable results, each with an
        ``owner`` attribute.  It is never modified.
    :returns: a new list holding the requested outputs first (in their
        original order) followed by the extra results discovered during
        the traversal; every result is listed once.
    """
    do_not_reuse = []
    listed = set()
    seen = set()

    def note(r):
        # Record r once, preserving first-seen order.
        if r not in listed:
            listed.add(r)
            do_not_reuse.append(r)

    def walk(r):
        if env.edge(r) or r in seen:
            return
        seen.add(r)
        note(r)
        op = r.owner
        dmap = op.destroy_map() if hasattr(op, 'destroy_map') else {}
        vmap = op.view_map() if hasattr(op, 'view_map') else {}
        # NOTE(review): each map value is assumed to be an iterable of
        # results aliased with one of the op's outputs -- confirm against
        # the Op implementations.
        # Plain nested loops instead of reduce(): reduce() without an
        # initial value raises TypeError when a map is empty.
        for aliased in list(dmap.values()) + list(vmap.values()):
            for r2 in aliased:
                walk(r2)

    # Seed the result with the requested outputs in a fresh list so the
    # caller's list is neither mutated nor iterated while it grows.
    for output in outputs_to_disown:
        note(output)
    for output in outputs_to_disown:
        walk(output)
    return do_not_reuse
# StateFunction([x, y], [e], (w, w + lr * bla())) # StateFunction([x, y], [e], (w, w + lr * bla()))
......
...@@ -412,11 +412,14 @@ class Broadcast(Op, Destroyer): ...@@ -412,11 +412,14 @@ class Broadcast(Op, Destroyer):
def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None): def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None, module_name = None):
scalar_name = scalar_opclass.__name__
if name is None: if name is None:
name = "Tensor" + scalar_opclass.__name__ name = scalar_name
if module_name is None:
module_name = 'elemwise.make_broadcast(%s, %s, %s)' % (scalar_name, inplace_pattern, repr(name))
name = "New"
scalar_name = scalar_opclass.__name__
previous_doc = Broadcast.__doc__ previous_doc = Broadcast.__doc__
scalar_doc = scalar_opclass.__doc__ or "" scalar_doc = scalar_opclass.__doc__ or ""
...@@ -449,6 +452,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None): ...@@ -449,6 +452,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
def desc(cls): def desc(cls):
return (Broadcast, scalar_opclass, tuple(inplace_pattern.items())) return (Broadcast, scalar_opclass, tuple(inplace_pattern.items()))
New.__name__ = name New.__name__ = name
New.__module__ = module_name
return New return New
def wrap_broadcast(op): def wrap_broadcast(op):
......
...@@ -493,4 +493,3 @@ def view_roots(r): ...@@ -493,4 +493,3 @@ def view_roots(r):
return [r] return [r]
else: else:
return [r] return [r]
...@@ -445,15 +445,21 @@ class _Op(Op): ...@@ -445,15 +445,21 @@ class _Op(Op):
# Unary Operations # Unary Operations
########################## ##########################
def broadcast(scalar_opclass, name, module_name = None, inplace_versions = True):
    """Build the broadcast Op class for a scalar op, plus its constructor.

    :param scalar_opclass: scalar Op class handed to ``s2t.make_broadcast``.
    :param name: name given to the generated class; the in-place variant
        is named ``name + "Inplace"``.
    :param module_name: forwarded to ``s2t.make_broadcast`` for both
        generated classes.  When it is None the default chosen by
        ``make_broadcast`` is kept instead of being overwritten with None.
    :param inplace_versions: when true, also build the in-place variant.
    :returns: ``(C, c, CInplace, c_inplace)`` when ``inplace_versions`` is
        true, otherwise ``(C, c)``.
    """
    # make_broadcast returns a class and assigns its __module__ itself, so
    # the module name is passed through rather than patched on afterwards.
    C = s2t.make_broadcast(scalar_opclass, name = name, module_name = module_name)
    c = gof.op.constructor(s2t.wrap_broadcast(C))
    if not inplace_versions:
        return C, c
    # NOTE(review): {0:0} presumably means output 0 overwrites input 0 --
    # confirm against Broadcast's inplace_pattern handling.
    CInplace = s2t.make_broadcast(scalar_opclass, {0:0}, name = name+"Inplace", module_name = module_name)
    c_inplace = gof.op.constructor(s2t.wrap_broadcast(CInplace))
    return C, c, CInplace, c_inplace
def _broadcast(scalar_opclass, name, inplace_versions = True):
    """Call ``broadcast`` with the module name fixed to ``'tensor'``."""
    return broadcast(scalar_opclass,
                     name,
                     module_name = 'tensor',
                     inplace_versions = inplace_versions)
class Argmax(Op): class Argmax(Op):
"""Calculate the max and argmax over a given axis""" """Calculate the max and argmax over a given axis"""
...@@ -487,32 +493,43 @@ def max(x, axis=None): ...@@ -487,32 +493,43 @@ def max(x, axis=None):
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
return argmax(x,axis)[0] return argmax(x,axis)[0]
Abs, _abs, AbsInplace, abs_inplace = broadcast(scal.Abs, 'Abs') Abs, _abs, AbsInplace, abs_inplace = _broadcast(scal.Abs, 'Abs')
Exp, exp, ExpInplace, exp_inplace = broadcast(scal.Exp, 'Exp') Exp, exp, ExpInplace, exp_inplace = _broadcast(scal.Exp, 'Exp')
Neg, neg, NegInplace, neg_inplace = broadcast(scal.Neg, 'Neg') Neg, neg, NegInplace, neg_inplace = _broadcast(scal.Neg, 'Neg')
Log, log, LogInplace, log_inplace = broadcast(scal.Log, 'Log') Log, log, LogInplace, log_inplace = _broadcast(scal.Log, 'Log')
Log2, log2, Log2Inplace, log2_inplace = broadcast(scal.Log2, 'Log2') Log2, log2, Log2Inplace, log2_inplace = _broadcast(scal.Log2, 'Log2')
Sgn, sgn, SgnInplace, sgn_inplace = broadcast(scal.Sgn, 'Sgn') Sgn, sgn, SgnInplace, sgn_inplace = _broadcast(scal.Sgn, 'Sgn')
Sqr, sqr, SqrInplace, sqr_inplace = broadcast(scal.Sqr, 'Sqr') Sqr, sqr, SqrInplace, sqr_inplace = _broadcast(scal.Sqr, 'Sqr')
Sqrt, sqrt, SqrtInplace, sqrt_inplace = broadcast(scal.Sqrt, 'Sqrt') Sqrt, sqrt, SqrtInplace, sqrt_inplace = _broadcast(scal.Sqrt, 'Sqrt')
Cos, cos, CosInplace, cos_inplace = broadcast(scal.Cos, 'Cos') Cos, cos, CosInplace, cos_inplace = _broadcast(scal.Cos, 'Cos')
Sin, sin, SinInplace, sin_inplace = broadcast(scal.Sin, 'Sin') Sin, sin, SinInplace, sin_inplace = _broadcast(scal.Sin, 'Sin')
Tan, tan, TanInplace, tan_inplace = broadcast(scal.Tan, 'Tan') Tan, tan, TanInplace, tan_inplace = _broadcast(scal.Tan, 'Tan')
Cosh, cosh, CoshInplace, cosh_inplace = broadcast(scal.Cosh, 'Cosh') Cosh, cosh, CoshInplace, cosh_inplace = _broadcast(scal.Cosh, 'Cosh')
Sinh, sinh, SinhInplace, sinh_inplace = broadcast(scal.Sinh, 'Sinh') Sinh, sinh, SinhInplace, sinh_inplace = _broadcast(scal.Sinh, 'Sinh')
Tanh, tanh, TanhInplace, tanh_inplace = broadcast(scal.Tanh, 'Tanh') Tanh, tanh, TanhInplace, tanh_inplace = _broadcast(scal.Tanh, 'Tanh')
Sum = s2t.Sum Fill, fill, FillInplace, fill_inplace = _broadcast(scal.Second, 'Fill')
sum = gof.op.constructor(Sum)
Fill, fill, FillInplace, fill_inplace = broadcast(scal.Second, 'Fill')
def ones_like(model): def ones_like(model):
return fill(model, 1.0) return fill(model, 1.0)
def zeros_like(model): def zeros_like(model):
return fill(model, 0.0) return fill(model, 0.0)
TensorCopy, tensor_copy = broadcast(scal.Identity, 'TensorCopy', False) TensorCopy, tensor_copy = _broadcast(scal.Identity, 'TensorCopy', inplace_versions = False)
Sum = s2t.Sum
sum = gof.op.constructor(Sum)
##########################
# Arithmetics
##########################
Add, add, AddInplace, add_inplace = _broadcast(scal.Add, 'Add')
Sub, sub, SubInplace, sub_inplace = _broadcast(scal.Sub, 'Sub')
Mul, mul, MulInplace, mul_inplace = _broadcast(scal.Mul, 'Mul')
Div, div, DivInplace, div_inplace = _broadcast(scal.Div, 'Div')
Pow, pow, PowInplace, pow_inplace = _broadcast(scal.Pow, 'Pow')
########################## ##########################
...@@ -606,17 +623,6 @@ class Subtensor(Op, Viewer): ...@@ -606,17 +623,6 @@ class Subtensor(Op, Viewer):
subtensor = gof.op.constructor(Subtensor) subtensor = gof.op.constructor(Subtensor)
##########################
# Arithmetics
##########################
Add, add, AddInplace, add_inplace = broadcast(scal.Add, 'Add')
Sub, sub, SubInplace, sub_inplace = broadcast(scal.Sub, 'Sub')
Mul, mul, MulInplace, mul_inplace = broadcast(scal.Mul, 'Mul')
Div, div, DivInplace, div_inplace = broadcast(scal.Div, 'Div')
Pow, pow, PowInplace, pow_inplace = broadcast(scal.Pow, 'Pow')
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
...@@ -624,8 +630,7 @@ Pow, pow, PowInplace, pow_inplace = broadcast(scal.Pow, 'Pow') ...@@ -624,8 +630,7 @@ Pow, pow, PowInplace, pow_inplace = broadcast(scal.Pow, 'Pow')
class Dot(_Op): class Dot(_Op):
nin=2 nin=2
nout=1 nout=1
@staticmethod def propagate_broadcastable(self, bx, by):
def broadcastable_rule(bx,by):
if len(bx) == 0: # x is a scalar if len(bx) == 0: # x is a scalar
rval = by rval = by
else: else:
...@@ -635,20 +640,11 @@ class Dot(_Op): ...@@ -635,20 +640,11 @@ class Dot(_Op):
rval = bx[:-1] rval = bx[:-1]
else: #y is a scalar else: #y is a scalar
rval = bx rval = bx
return rval return [rval]
def propagate_broadcastable(self, bx, by):
return [self.broadcastable_rule(bx,by)]
def impl(self, x, y): def impl(self, x, y):
return numpy.dot(x, y) return numpy.dot(x, y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
return dot(gz, y.T), dot(x.T, gz) return dot(gz, y.T), dot(x.T, gz)
if 0:
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0')
dot = gof.op.constructor(Dot) dot = gof.op.constructor(Dot)
class Gemm(_Op): class Gemm(_Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论