提交 ed72a991 authored 作者: James Bergstra's avatar James Bergstra

Added convert_to_<type> functions to scalar, and a corresponding cast(x, dtype)

function. This corresponds to what is already in theano.tensor.
上级 a4f8ef4e
...@@ -621,8 +621,6 @@ class Add(ScalarOp): ...@@ -621,8 +621,6 @@ class Add(ScalarOp):
else: else:
retval += [None] retval += [None]
return retval return retval
#backport
#return [(gz if i.type in grad_types else None) for i in inputs]
add = Add(upcast_out, name = 'add') add = Add(upcast_out, name = 'add')
class Mul(ScalarOp): class Mul(ScalarOp):
...@@ -668,8 +666,6 @@ class Sub(BinaryScalarOp): ...@@ -668,8 +666,6 @@ class Sub(BinaryScalarOp):
second_part = None second_part = None
return first_part, second_part return first_part, second_part
#return gz if x.type in grad_types else None, -gz if y.type in grad_types else None
sub = Sub(upcast_out, name = 'sub') sub = Sub(upcast_out, name = 'sub')
def div_proxy(x, y): def div_proxy(x, y):
...@@ -707,11 +703,7 @@ class TrueDiv(BinaryScalarOp): ...@@ -707,11 +703,7 @@ class TrueDiv(BinaryScalarOp):
second_part = -(gz * x) / (y * y) second_part = -(gz * x) / (y * y)
else: else:
second_part = None second_part = None
return first_part, second_part
return (first_part, second_part)
#return (gz / y if x.type in grad_types else None,
# -(gz * x) / (y * y) if y.type in grad_types else None)
true_div = TrueDiv(upcast_out, name = 'true_div') true_div = TrueDiv(upcast_out, name = 'true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
...@@ -825,6 +817,31 @@ class Identity(UnaryScalarOp): ...@@ -825,6 +817,31 @@ class Identity(UnaryScalarOp):
#return gz if x.type in grad_types else None, #return gz if x.type in grad_types else None,
identity = Identity(same_out, name = 'identity') identity = Identity(same_out, name = 'identity')
#### CASTING OPERATIONS
convert_to_int8 = Identity(specific_out(int8), name='convert_to_int8')
convert_to_int16 = Identity(specific_out(int16), name='convert_to_int16')
convert_to_int32 = Identity(specific_out(int32), name='convert_to_int32')
convert_to_int64 = Identity(specific_out(int64), name='convert_to_int64')
convert_to_float32 = Identity(specific_out(float32), name='convert_to_float32')
convert_to_float64 = Identity(specific_out(float64), name='convert_to_float64')
convert_to_complex64 = Identity(specific_out(complex64), name='convert_to_complex64')
convert_to_complex128 = Identity(specific_out(complex128), name='convert_to_complex128')
def cast(t, dtype):
if t.type.dtype == dtype:
return t
"""symbolically cast `t` to a Scalar of type `dtype`."""
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16,
'int32': convert_to_int32,
'int64': convert_to_int64,
'float32': convert_to_float32,
'float64': convert_to_float64,
'complex64': convert_to_complex64,
'complex128': convert_to_complex128}
return mapping[dtype](t)
class Abs(UnaryScalarOp): class Abs(UnaryScalarOp):
def make_node(self, x): def make_node(self, x):
inputs = [as_scalar(input) for input in [x]] inputs = [as_scalar(input) for input in [x]]
...@@ -883,8 +900,6 @@ class Neg(UnaryScalarOp): ...@@ -883,8 +900,6 @@ class Neg(UnaryScalarOp):
return -gz, return -gz,
else: else:
return None, return None,
#backport
#return -gz if x.type in grad_types else None,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name = 'neg') neg = Neg(same_out, name = 'neg')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论