提交 3cccea1f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove Scan.*_is_tensor attributes

上级 89b0b822
...@@ -705,21 +705,6 @@ class Scan(Op, ScanMethodsMixin): ...@@ -705,21 +705,6 @@ class Scan(Op, ScanMethodsMixin):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if getattr(self, "_fn", None) is not None:
if not hasattr(self, "outs_is_tensor"):
# The thunk has been compiled before the analysis, at
# compilation time, of the location of the inputs and outputs.
# Perform this analysis here.
self.inps_is_tensor = [
isinstance(out, TensorVariable)
for out in self.fn.maker.fgraph.inputs
]
self.outs_is_tensor = [
isinstance(out, TensorVariable)
for out in self.fn.maker.fgraph.outputs
]
# Ensure that the graph associated with the inner function is valid. # Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph() self.validate_inner_graph()
...@@ -1376,10 +1361,10 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1376,10 +1361,10 @@ class Scan(Op, ScanMethodsMixin):
# Analyse the compile inner function to determine which inputs and # Analyse the compile inner function to determine which inputs and
# outputs are on the gpu and speed up some checks during the execution # outputs are on the gpu and speed up some checks during the execution
self.inps_is_tensor = [ inps_is_tensor = [
isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.inputs isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.inputs
] ]
self.outs_is_tensor = [ outs_is_tensor = [
isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs isinstance(out, TensorVariable) for out in self.fn.maker.fgraph.outputs
] ]
...@@ -1420,8 +1405,8 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1420,8 +1405,8 @@ class Scan(Op, ScanMethodsMixin):
self.mitmots_preallocated, dtype="int32" self.mitmots_preallocated, dtype="int32"
) )
cython_inps_is_tensor = np.asarray(self.inps_is_tensor, dtype="int32") cython_inps_is_tensor = np.asarray(inps_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(self.outs_is_tensor, dtype="int32") cython_outs_is_tensor = np.asarray(outs_is_tensor, dtype="int32")
if self.destroy_map: if self.destroy_map:
cython_destroy_map = [ cython_destroy_map = [
...@@ -1695,7 +1680,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1695,7 +1680,7 @@ class Scan(Op, ScanMethodsMixin):
if var is None: if var is None:
old_inner_output_data[idx] = None old_inner_output_data[idx] = None
elif self.outs_is_tensor[idx]: elif isinstance(self.fn.maker.fgraph.outputs[idx], TensorVariable):
old_inner_output_data[idx] = var.data old_inner_output_data[idx] = var.data
else: else:
old_inner_output_data[idx] = var.gpudata old_inner_output_data[idx] = var.gpudata
...@@ -1713,7 +1698,9 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1713,7 +1698,9 @@ class Scan(Op, ScanMethodsMixin):
if var is None: if var is None:
old_mitmot_input_data[idx] = None old_mitmot_input_data[idx] = None
elif self.inps_is_tensor[idx + self.n_seqs]: elif isinstance(
self.fn.maker.fgraph.inputs[idx + self.n_seqs], TensorVariable
):
old_mitmot_input_data[idx] = var.data old_mitmot_input_data[idx] = var.data
else: else:
old_mitmot_input_data[idx] = var.gpudata old_mitmot_input_data[idx] = var.gpudata
...@@ -1783,7 +1770,10 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1783,7 +1770,10 @@ class Scan(Op, ScanMethodsMixin):
new_var = inner_input_storage[self.n_seqs + inp_idx].storage[0] new_var = inner_input_storage[self.n_seqs + inp_idx].storage[0]
if old_var is new_var: if old_var is new_var:
old_data = old_mitmot_input_data[inp_idx] old_data = old_mitmot_input_data[inp_idx]
if self.inps_is_tensor[self.n_seqs + inp_idx]: if isinstance(
self.fn.maker.fgraph.inputs[self.n_seqs + inp_idx],
TensorVariable,
):
same_data = new_var.data == old_data same_data = new_var.data == old_data
else: else:
same_data = new_var.gpudata == old_data same_data = new_var.gpudata == old_data
...@@ -1832,7 +1822,9 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1832,7 +1822,9 @@ class Scan(Op, ScanMethodsMixin):
old_data = old_inner_output_data[offset_out + j] old_data = old_inner_output_data[offset_out + j]
if old_data is None: if old_data is None:
output_reused = False output_reused = False
elif self.outs_is_tensor[offset_out + j]: elif isinstance(
self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
):
output_reused = new_var.data == old_data output_reused = new_var.data == old_data
else: else:
output_reused = new_var.gpudata == old_data output_reused = new_var.gpudata == old_data
...@@ -1893,7 +1885,9 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1893,7 +1885,9 @@ class Scan(Op, ScanMethodsMixin):
if old_var is new_var: if old_var is new_var:
if old_data is None: if old_data is None:
output_reused = False output_reused = False
elif self.outs_is_tensor[offset_out + j]: elif isinstance(
self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
):
output_reused = new_var.data == old_data output_reused = new_var.data == old_data
else: else:
output_reused = new_var.gpudata == old_data output_reused = new_var.gpudata == old_data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论