提交 663ddaa0 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

some stuff

上级 ef638c6a
......@@ -189,6 +189,24 @@ def eval_outputs(outputs,
return rval
def infer_reuse_pattern(env, outputs_to_disown):
do_not_reuse = outputs_to_disown
seen = set()
def walk(r):
if env.edge(r) or r in seen:
return
seen.add(r)
do_not_reuse.append(r)
op = r.owner
dmap = op.destroy_map() if hasattr(op, 'destroy_map') else {}
vmap = op.view_map() if hasattr(op, 'view_map') else {}
cat = lambda x, y: list(x) + list(y)
for r2 in reduce(cat, dmap.values()) + reduce(cat, vmap.values()):
accumulate(r2)
for output in outputs_to_disown:
walk(output)
return do_not_reuse
# StateFunction([x, y], [e], (w, w + lr * bla()))
......
......@@ -412,11 +412,14 @@ class Broadcast(Op, Destroyer):
def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None, module_name = None):
scalar_name = scalar_opclass.__name__
if name is None:
name = "Tensor" + scalar_opclass.__name__
name = scalar_name
if module_name is None:
module_name = 'elemwise.make_broadcast(%s, %s, %s)' % (scalar_name, inplace_pattern, repr(name))
name = "New"
scalar_name = scalar_opclass.__name__
previous_doc = Broadcast.__doc__
scalar_doc = scalar_opclass.__doc__ or ""
......@@ -449,6 +452,7 @@ def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
def desc(cls):
return (Broadcast, scalar_opclass, tuple(inplace_pattern.items()))
New.__name__ = name
New.__module__ = module_name
return New
def wrap_broadcast(op):
......
......@@ -493,4 +493,3 @@ def view_roots(r):
return [r]
else:
return [r]
......@@ -445,15 +445,21 @@ class _Op(Op):
# Unary Operations
##########################
def broadcast(scalar_opclass, name, inplace_versions = True):
C = s2t.make_broadcast(scalar_opclass, name = name)
def broadcast(scalar_opclass, name, module_name = None, inplace_versions = True):
C = s2t.make_broadcast(scalar_opclass, name = name, module_name = module_name) # this returns a class
C.__module__ = module_name
c = gof.op.constructor(s2t.wrap_broadcast(C))
if inplace_versions:
CInplace = s2t.make_broadcast(scalar_opclass, {0:0}, name = name+"Inplace")
CInplace.__module__ = module_name
c_inplace = gof.op.constructor(s2t.wrap_broadcast(CInplace))
return C, c, CInplace, c_inplace
else:
return C, c
def _broadcast(scalar_opclass, name, inplace_versions = True):
return broadcast(scalar_opclass, name, 'tensor', inplace_versions)
class Argmax(Op):
"""Calculate the max and argmax over a given axis"""
......@@ -487,32 +493,43 @@ def max(x, axis=None):
# but when Argmax.c_impl() is in place, it should be fine.
return argmax(x,axis)[0]
Abs, _abs, AbsInplace, abs_inplace = broadcast(scal.Abs, 'Abs')
Exp, exp, ExpInplace, exp_inplace = broadcast(scal.Exp, 'Exp')
Neg, neg, NegInplace, neg_inplace = broadcast(scal.Neg, 'Neg')
Log, log, LogInplace, log_inplace = broadcast(scal.Log, 'Log')
Log2, log2, Log2Inplace, log2_inplace = broadcast(scal.Log2, 'Log2')
Sgn, sgn, SgnInplace, sgn_inplace = broadcast(scal.Sgn, 'Sgn')
Sqr, sqr, SqrInplace, sqr_inplace = broadcast(scal.Sqr, 'Sqr')
Sqrt, sqrt, SqrtInplace, sqrt_inplace = broadcast(scal.Sqrt, 'Sqrt')
Cos, cos, CosInplace, cos_inplace = broadcast(scal.Cos, 'Cos')
Sin, sin, SinInplace, sin_inplace = broadcast(scal.Sin, 'Sin')
Tan, tan, TanInplace, tan_inplace = broadcast(scal.Tan, 'Tan')
Cosh, cosh, CoshInplace, cosh_inplace = broadcast(scal.Cosh, 'Cosh')
Sinh, sinh, SinhInplace, sinh_inplace = broadcast(scal.Sinh, 'Sinh')
Tanh, tanh, TanhInplace, tanh_inplace = broadcast(scal.Tanh, 'Tanh')
Sum = s2t.Sum
sum = gof.op.constructor(Sum)
Fill, fill, FillInplace, fill_inplace = broadcast(scal.Second, 'Fill')
Abs, _abs, AbsInplace, abs_inplace = _broadcast(scal.Abs, 'Abs')
Exp, exp, ExpInplace, exp_inplace = _broadcast(scal.Exp, 'Exp')
Neg, neg, NegInplace, neg_inplace = _broadcast(scal.Neg, 'Neg')
Log, log, LogInplace, log_inplace = _broadcast(scal.Log, 'Log')
Log2, log2, Log2Inplace, log2_inplace = _broadcast(scal.Log2, 'Log2')
Sgn, sgn, SgnInplace, sgn_inplace = _broadcast(scal.Sgn, 'Sgn')
Sqr, sqr, SqrInplace, sqr_inplace = _broadcast(scal.Sqr, 'Sqr')
Sqrt, sqrt, SqrtInplace, sqrt_inplace = _broadcast(scal.Sqrt, 'Sqrt')
Cos, cos, CosInplace, cos_inplace = _broadcast(scal.Cos, 'Cos')
Sin, sin, SinInplace, sin_inplace = _broadcast(scal.Sin, 'Sin')
Tan, tan, TanInplace, tan_inplace = _broadcast(scal.Tan, 'Tan')
Cosh, cosh, CoshInplace, cosh_inplace = _broadcast(scal.Cosh, 'Cosh')
Sinh, sinh, SinhInplace, sinh_inplace = _broadcast(scal.Sinh, 'Sinh')
Tanh, tanh, TanhInplace, tanh_inplace = _broadcast(scal.Tanh, 'Tanh')
Fill, fill, FillInplace, fill_inplace = _broadcast(scal.Second, 'Fill')
def ones_like(model):
return fill(model, 1.0)
def zeros_like(model):
return fill(model, 0.0)
TensorCopy, tensor_copy = broadcast(scal.Identity, 'TensorCopy', False)
TensorCopy, tensor_copy = _broadcast(scal.Identity, 'TensorCopy', inplace_versions = False)
Sum = s2t.Sum
sum = gof.op.constructor(Sum)
##########################
# Arithmetics
##########################
Add, add, AddInplace, add_inplace = _broadcast(scal.Add, 'Add')
Sub, sub, SubInplace, sub_inplace = _broadcast(scal.Sub, 'Sub')
Mul, mul, MulInplace, mul_inplace = _broadcast(scal.Mul, 'Mul')
Div, div, DivInplace, div_inplace = _broadcast(scal.Div, 'Div')
Pow, pow, PowInplace, pow_inplace = _broadcast(scal.Pow, 'Pow')
##########################
......@@ -606,17 +623,6 @@ class Subtensor(Op, Viewer):
subtensor = gof.op.constructor(Subtensor)
##########################
# Arithmetics
##########################
Add, add, AddInplace, add_inplace = broadcast(scal.Add, 'Add')
Sub, sub, SubInplace, sub_inplace = broadcast(scal.Sub, 'Sub')
Mul, mul, MulInplace, mul_inplace = broadcast(scal.Mul, 'Mul')
Div, div, DivInplace, div_inplace = broadcast(scal.Div, 'Div')
Pow, pow, PowInplace, pow_inplace = broadcast(scal.Pow, 'Pow')
#########################
# Linalg : Dot
#########################
......@@ -624,8 +630,7 @@ Pow, pow, PowInplace, pow_inplace = broadcast(scal.Pow, 'Pow')
class Dot(_Op):
nin=2
nout=1
@staticmethod
def broadcastable_rule(bx,by):
def propagate_broadcastable(self, bx, by):
if len(bx) == 0: # x is a scalar
rval = by
else:
......@@ -635,20 +640,11 @@ class Dot(_Op):
rval = bx[:-1]
else: #y is a scalar
rval = bx
return rval
def propagate_broadcastable(self, bx, by):
return [self.broadcastable_rule(bx,by)]
return [rval]
def impl(self, x, y):
return numpy.dot(x, y)
def grad(self, (x, y), (gz,)):
return dot(gz, y.T), dot(x.T, gz)
if 0:
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0')
dot = gof.op.constructor(Dot)
class Gemm(_Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论