提交 22fb1439 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix dtype to use output.dtype and add necessary optimizations for tests

上级 0f1edb05
...@@ -4246,11 +4246,11 @@ def local_useless_elemwise_comparison(node): ...@@ -4246,11 +4246,11 @@ def local_useless_elemwise_comparison(node):
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \ if isinstance(node.op.scalar_op, (scalar.LT, scalar.GT)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].type.dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \ if isinstance(node.op.scalar_op, (scalar.LE, scalar.GE)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
return [T.ones_like(node.inputs[0], dtype=node.outputs[0].type.dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \ if isinstance(node.op.scalar_op, (scalar.Minimum, scalar.Maximum)) and \
node.inputs[0] is node.inputs[1]: node.inputs[0] is node.inputs[1]:
...@@ -4261,13 +4261,13 @@ def local_useless_elemwise_comparison(node): ...@@ -4261,13 +4261,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].type.dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1]) == 0:
return [T.ones_like(node.inputs[0], dtype=node.inputs[1].type.dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if isinstance(node.op.scalar_op, scalar.Maximum) and \ if isinstance(node.op.scalar_op, scalar.Maximum) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -4285,13 +4285,13 @@ def local_useless_elemwise_comparison(node): ...@@ -4285,13 +4285,13 @@ def local_useless_elemwise_comparison(node):
node.inputs[0].owner and \ node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, Shape_i) and \ isinstance(node.inputs[0].owner.op, Shape_i) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].type.dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](0, X.shape[i]) -> 0
if isinstance(node.op.scalar_op, scalar.Minimum) and \ if isinstance(node.op.scalar_op, scalar.Minimum) and \
T.extract_constant(node.inputs[0]) == 0 and \ T.extract_constant(node.inputs[0]) == 0 and \
node.inputs[1].owner and \ node.inputs[1].owner and \
isinstance(node.inputs[1].owner.op, Shape_i): isinstance(node.inputs[1].owner.op, Shape_i):
return [T.zeros_like(node.inputs[1], dtype=node.inputs[0].type.dtype)] return [T.zeros_like(node.inputs[1], dtype=node.outputs[0].dtype)]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if isinstance(node.op.scalar_op, scalar.LT) and \ if isinstance(node.op.scalar_op, scalar.LT) and \
...@@ -4302,7 +4302,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4302,7 +4302,7 @@ def local_useless_elemwise_comparison(node):
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1]) == 0:
return [T.zeros_like(node.inputs[0], dtype=node.inputs[1].dtype)] return [T.zeros_like(node.inputs[0], dtype=node.outputs[0].dtype)]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if isinstance(node.op.scalar_op, scalar.GE) and \ if isinstance(node.op.scalar_op, scalar.GE) and \
node.inputs[0].owner and \ node.inputs[0].owner and \
...@@ -4311,7 +4311,7 @@ def local_useless_elemwise_comparison(node): ...@@ -4311,7 +4311,7 @@ def local_useless_elemwise_comparison(node):
all([isinstance(var.owner and var.owner.op, Shape_i) all([isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs]) and \ for var in node.inputs[0].owner.inputs]) and \
T.extract_constant(node.inputs[1]) == 0: T.extract_constant(node.inputs[1]) == 0:
return [T.ones_like(node.inputs[0], dtype=node.inputs[1].dtype)] return [T.ones_like(node.inputs[0], dtype=node.outputs[0].dtype)]
return return
......
...@@ -3245,8 +3245,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3245,8 +3245,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_shape_inequality_with_self(self): def test_shape_inequality_with_self(self):
x = T.vector('x', dtype=config.floatX) x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison') mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison',
'local_shape_to_shape_i',
'local_track_shape_i',
'local_subtensor_make_vector')
f = theano.function([x], T.lt(x.shape[0], 0), mode=mode) f = theano.function([x], T.lt(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
...@@ -3275,7 +3277,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3275,7 +3277,10 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_shape_add_inequality(self): def test_shape_add_inequality(self):
x = T.vector('x', dtype=config.floatX) x = T.vector('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison') mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison',
'local_shape_to_shape_i',
'local_track_shape_i',
'local_subtensor_make_vector')
y = T.vector('y', dtype=config.floatX) y = T.vector('y', dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论