Commit 5d1a7558 authored by Frederic Bastien
...@@ -417,6 +417,30 @@ class NEQ(LogicalComparison): ...@@ -417,6 +417,30 @@ class NEQ(LogicalComparison):
return x != y return x != y
neq = NEQ() neq = NEQ()
class InRange(LogicalComparison):
    """Scalar op testing whether x lies within the interval [low, hi].

    ``openlow`` / ``openhi`` select a strict (open) or inclusive (closed)
    comparison at the lower / upper bound respectively.
    """
    nin = 3

    def __init__(self, openlow, openhi):
        # True -> strict comparison at that bound; False -> inclusive.
        self.openlow = openlow
        self.openhi = openhi

    def impl(self, x, low, hi):
        """Python-side implementation: True iff x is inside the range."""
        # Reject on the lower bound first, then the upper bound.
        if self.openlow and x <= low:
            return False
        elif not self.openlow and x < low:
            return False
        if self.openhi and x >= hi:
            return False
        elif not self.openhi and x > hi:
            return False
        return True

    def c_code(self, node, name, inputs, outputs, sub):
        # FIX: the original used tuple parameter unpacking in the signature,
        # which is Python-2-only syntax (removed by PEP 3113).  Unpack
        # explicitly instead; the positional interface is unchanged.
        (x, low, hi) = inputs
        (z,) = outputs
        cmp1 = '>' if self.openlow else '>='
        cmp2 = '<' if self.openhi else '<='
        return "%(z)s = %(x)s %(cmp1)s %(low)s && %(x)s %(cmp2)s %(hi)s;" % locals()

    def grad(self, inputs, grads):
        # The output is boolean, so it is not differentiable with respect
        # to any of the three inputs.
        (x, low, hi) = inputs
        (gz,) = grads
        return None, None, None
inopenrange = InRange(True, True)
inclosedrange = InRange(False, False)
#################### ####################
# BIT-WISE OPERATORS # BIT-WISE OPERATORS
#################### ####################
...@@ -545,6 +569,17 @@ class Pow(BinaryScalarOp): ...@@ -545,6 +569,17 @@ class Pow(BinaryScalarOp):
return gz * y * x**(y - 1), gz * log(x) * x**y return gz * y * x**(y - 1), gz * log(x) * x**y
pow = Pow(upcast_out, name = 'pow') pow = Pow(upcast_out, name = 'pow')
class Clip(ScalarOp):
    """Scalar op clamping x into the closed interval [min, max]."""
    nin = 3

    def impl(self, x, min, max):
        """Python-side implementation: min if x < min, max if x > max, else x."""
        return min if x < min else max if x > max else x

    def c_code(self, node, name, inputs, outputs, sub):
        # FIX: the original signature used tuple parameter unpacking,
        # which is Python-2-only syntax (removed by PEP 3113).  Unpack
        # explicitly; the positional interface is unchanged.
        (x, min, max) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()

    def grad(self, inputs, grads):
        """Gradient w.r.t. x only; the bounds get no gradient."""
        (x, min, max) = inputs
        (gz,) = grads
        # BUG FIX: '&' binds tighter than comparison operators, so the
        # original 'x > min & x < max' parsed as the chained comparison
        # 'x > (min & x) < max'.  Parenthesize each comparison before
        # and-ing them, so gz flows only where x is strictly inside
        # the interval.
        gx = ((x > min) & (x < max)) * gz
        return gx, None, None
clip = Clip(transfer_type(0), name = 'clip')
class First(BinaryScalarOp): class First(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x return x
......
...@@ -583,6 +583,14 @@ class _tensor_py_operators: ...@@ -583,6 +583,14 @@ class _tensor_py_operators:
def sum(self, axis=None): def sum(self, axis=None):
return elemwise.Sum(axis)(self) return elemwise.Sum(axis)(self)
def norm(self, L, axis=None):
    """Return the L-norm of this tensor, optionally along ``axis``.

    Computes ``(sum(self ** L)) ** (1.0 / L)``.  The cases L == 0 and
    L == inf are not implemented.
    """
    if L == 0 or L == float('inf'):
        raise NotImplementedError()
    # Optimizations will/should catch special cases such as L=1, L=2.
    powered_sum = pow(self, L).sum(axis=axis)
    return pow(powered_sum, 1.0 / L)
class TensorResult(Result, _tensor_py_operators): class TensorResult(Result, _tensor_py_operators):
...@@ -1144,6 +1152,10 @@ def mod(a, b): ...@@ -1144,6 +1152,10 @@ def mod(a, b):
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# Stub: the @_scal_elemwise decorator replaces this function with the
# elementwise version of the scalar Clip op; only the name, signature
# and docstring of the stub are used.
@_scal_elemwise
def clip(x, min, max):
    """Clip each element of x to lie between min and max.

    Elementwise: returns min where x < min, max where x > max, and x
    otherwise.  The parameters `min` and `max` shadow the Python
    builtins; the names are kept for interface compatibility.
    """
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either')) pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either')) pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left')) pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
......
...@@ -591,7 +591,7 @@ class Elemwise(Op): ...@@ -591,7 +591,7 @@ class Elemwise(Op):
task_code = self.scalar_op.c_code(Apply(self.scalar_op, task_code = self.scalar_op.c_code(Apply(self.scalar_op,
[Scalar(dtype = input.type.dtype)() for input in node.inputs], [Scalar(dtype = input.type.dtype)() for input in node.inputs],
[Scalar(dtype = output.type.dtype)() for input in node.outputs]), [Scalar(dtype = output.type.dtype)() for input in node.outputs]),
None, name + '_scalar_',
["%s_i" % s for s in _inames], ["%s_i" % s for s in _inames],
["%s_i" % s for s in onames], ["%s_i" % s for s in onames],
sub) sub)
...@@ -615,6 +615,9 @@ class Elemwise(Op): ...@@ -615,6 +615,9 @@ class Elemwise(Op):
code = "\n".join(self._c_all(node, name, inames, onames, sub)) code = "\n".join(self._c_all(node, name, inames, onames, sub))
return code return code
def c_support_code(self):
    """Delegate C support-code generation to the wrapped scalar op."""
    scalar_op = self.scalar_op
    return scalar_op.c_support_code()
# def elemwise_to_scal(env): # def elemwise_to_scal(env):
# mapping = {} # mapping = {}
# inputs = [] # inputs = []
......
...@@ -475,6 +475,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -475,6 +475,8 @@ class Canonizer(gof.LocalOptimizer):
return v return v
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
return v.data return v.data
if not hasattr(v, 'owner'):
return v
if v.owner and isinstance(v.owner.op, DimShuffle): if v.owner and isinstance(v.owner.op, DimShuffle):
return cls.get_constant(v.owner.inputs[0]) return cls.get_constant(v.owner.inputs[0])
return None return None
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment