提交 3cccea1f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove Scan.*_is_tensor attributes

上级 89b0b822
......@@ -705,21 +705,6 @@ class Scan(Op, ScanMethodsMixin):
def __setstate__(self, d):
self.__dict__.update(d)
if getattr(self, "_fn", None) is not None:
if not hasattr(self, "outs_is_tensor"):
# The thunk has been compiled before the analysis, at
# compilation time, of the location of the inputs and outputs.
# Perform this analysis here.
self.inps_is_tensor = [
isinstance(out, TensorVariable)
for out in self.fn.maker.fgraph.inputs
]
self.outs_is_tensor = [
isinstance(out, TensorVariable)
for out in self.fn.maker.fgraph.outputs
]
# Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph()
......@@ -1376,10 +1361,10 @@ class Scan(Op, ScanMethodsMixin):
# Analyse the compile inner function to determine which inputs and
# outputs are on the gpu and speed up some checks during the execution
self.inps_is_tensor = [
inps_is_tensor = [
isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.inputs
]
self.outs_is_tensor = [
outs_is_tensor = [
isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
]
......@@ -1420,8 +1405,8 @@ class Scan(Op, ScanMethodsMixin):
self.mitmots_preallocated, dtype="int32"
)
cython_inps_is_tensor = np.asarray(self.inps_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(self.outs_is_tensor, dtype="int32")
cython_inps_is_tensor = np.asarray(inps_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype="int32")
if self.destroy_map:
cython_destroy_map = [
......@@ -1695,7 +1680,7 @@ class Scan(Op, ScanMethodsMixin):
if var is None:
old_inner_output_data[idx] = None
elif self.outs_is_tensor[idx]:
elif isinstance(self.fn.maker.fgraph.outputs[idx], TensorVariable):
old_inner_output_data[idx] = var.data
else:
old_inner_output_data[idx] = var.gpudata
......@@ -1713,7 +1698,9 @@ class Scan(Op, ScanMethodsMixin):
if var is None:
old_mitmot_input_data[idx] = None
elif self.inps_is_tensor[idx + self.n_seqs]:
elif isinstance(
self.fn.maker.fgraph.inputs[idx + self.n_seqs], TensorVariable
):
old_mitmot_input_data[idx] = var.data
else:
old_mitmot_input_data[idx] = var.gpudata
......@@ -1783,7 +1770,10 @@ class Scan(Op, ScanMethodsMixin):
new_var = inner_input_storage[self.n_seqs + inp_idx].storage[0]
if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx]
if self.inps_is_tensor[self.n_seqs + inp_idx]:
if isinstance(
self.fn.maker.fgraph.inputs[self.n_seqs + inp_idx],
TensorVariable,
):
same_data = new_var.data == old_data
else:
same_data = new_var.gpudata == old_data
......@@ -1832,7 +1822,9 @@ class Scan(Op, ScanMethodsMixin):
old_data = old_inner_output_data[offset_out + j]
if old_data is None:
output_reused = False
elif self.outs_is_tensor[offset_out + j]:
elif isinstance(
self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
):
output_reused = new_var.data == old_data
else:
output_reused = new_var.gpudata == old_data
......@@ -1893,7 +1885,9 @@ class Scan(Op, ScanMethodsMixin):
if old_var is new_var:
if old_data is None:
output_reused = False
elif self.outs_is_tensor[offset_out + j]:
elif isinstance(
self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
):
output_reused = new_var.data == old_data
else:
output_reused = new_var.gpudata == old_data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论