提交 5f04a6bc authored 作者: dima's avatar dima

Fixed comparisons with None which were displayed during the test run

上级 5ec4c302
...@@ -1935,11 +1935,11 @@ def local_useless_subtensor(node): ...@@ -1935,11 +1935,11 @@ def local_useless_subtensor(node):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless # from the output, so the subtensor is not useless
return False return False
if idx.start not in [0, None]: if idx.start is not None and idx.start != 0:
# If the start of the slice is different from 0, or is a # If the start of the slice is different from 0, or is a
# variable, then we assume the subtensor is not useless # variable, then we assume the subtensor is not useless
return False return False
if idx.step not in [1, None]: if idx.step is not None and idx.step != 1:
# If we are going backwards, or skipping elements, then this # If we are going backwards, or skipping elements, then this
# is not a useless subtensor # is not a useless subtensor
return False return False
...@@ -2543,9 +2543,9 @@ def local_setsubtensor_of_constants(node): ...@@ -2543,9 +2543,9 @@ def local_setsubtensor_of_constants(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if (replace_x == replace_y and if (replace_x is not None and
replace_x is not None and replace_y is not None and
replace_y is not None): replace_x == replace_y):
return [x] return [x]
else: else:
return False return False
......
...@@ -141,7 +141,7 @@ def get_canonical_form_slice(theslice, length): ...@@ -141,7 +141,7 @@ def get_canonical_form_slice(theslice, length):
(is_start_constant and is_length_constant and (is_start_constant and is_length_constant and
start < 0 and start + length <= 0)) start < 0 and start + length <= 0))
is_stop_length = ( is_stop_length = (
stop in [None, length, maxsize] or stop is None or stop in [length, maxsize] or
(is_stop_constant and is_length_constant and (is_stop_constant and is_length_constant and
stop >= length)) stop >= length))
if is_start_0: if is_start_0:
...@@ -217,7 +217,7 @@ def get_canonical_form_slice(theslice, length): ...@@ -217,7 +217,7 @@ def get_canonical_form_slice(theslice, length):
start = switch(ge(start, length), start = switch(ge(start, length),
switch_neg_step(length - 1, length), switch_neg_step(length - 1, length),
start) start)
if stop in [None, maxsize]: if stop is None or stop == maxsize:
# The special "maxsize" case is probably not needed here, # The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by # as slices containing maxsize are not generated by
# __getslice__ anymore. # __getslice__ anymore.
...@@ -478,7 +478,7 @@ class Subtensor(Op): ...@@ -478,7 +478,7 @@ class Subtensor(Op):
start = get_scalar_constant_value(start) start = get_scalar_constant_value(start)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if start in [None, 0]: if start is None or start == 0:
start = p.start start = p.start
if start is None: if start is None:
start = 0 start = 0
......
...@@ -369,7 +369,7 @@ class _tensor_py_operators: ...@@ -369,7 +369,7 @@ class _tensor_py_operators:
axis = None axis = None
for i, arg in enumerate(args): for i, arg in enumerate(args):
try: try:
if arg != numpy.newaxis: if arg is not numpy.newaxis:
theano.tensor.subtensor.Subtensor.convert(arg) theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError: except theano.tensor.subtensor.AdvancedIndexingError:
if advanced: if advanced:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论