提交 f608b307 authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow binary bit-wise op on int16(previous commit was for unary bit-wise) and…

Allow binary bit-wise op on int16(previous commit was for unary bit-wise) and add c code for bit-wise and, or, xor and invert.
上级 3912f005
......@@ -788,7 +788,7 @@ class BinaryBitOp(BinaryScalarOp):
def output_types(self, *input_types):
t0, t1 = input_types[0]
for i in input_types[0]:
if i not in (int8, int32, int64):
if i not in (int8, int16, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
......@@ -800,6 +800,8 @@ class OR(BinaryBitOp):
associative = False
def impl(self, x, y):
return x | y
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = (%(x)s | %(y)s);" % locals()
or_ = OR()
class XOR(BinaryBitOp):
......@@ -808,6 +810,8 @@ class XOR(BinaryBitOp):
associative = False
def impl(self, x, y):
return x ^ y
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = (%(x)s ^ %(y)s);" % locals()
xor = XOR()
class AND(BinaryBitOp):
......@@ -816,12 +820,29 @@ class AND(BinaryBitOp):
associative = False
def impl(self, x, y):
return x & y
def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = (%(x)s & %(y)s);" % locals()
and_ = AND()
class Invert(UnaryBitOp):
identity = False
def impl(self, x):
return ~x
def c_code(self, node, name, (x,), (z, ), sub):
dtype = node.inputs[0].type.dtype
# For an unknow reason, the pattern must have 2 times the number of bits
# then the inputs...
if dtype == 'int8':
pattern = "0xFF"
elif dtype == 'int16':
pattern = "0xFFFF"
elif dtype == 'int32':
pattern = "0xFFFFFFFF"
elif dtype == 'int64':
pattern = "0xFFFFFFFFFFFFFFFF"
else:
super(Invert, self).c_code(node, name, (x,), (z, ), sub)
return "%(z)s = (%(x)s ^ %(pattern)s);" % locals()
invert = Invert()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论