提交 5f470abb authored 作者: Razvan Pascanu's avatar Razvan Pascanu

list of expandable types used by scan for shared variables

上级 56f011a4
...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op): ...@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op):
deep_copy_op = DeepCopyOp() deep_copy_op = DeepCopyOp()
# List of Theano Types that one can add an extra dimension and for which
# Scan can deal with.
expandable_types = ()
...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type): ...@@ -411,6 +411,7 @@ class CudaNdarrayType(Type):
def c_compile_args(self): def c_compile_args(self):
return [] return []
theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType # Register C code for ViewOp on CudaNdarrayType
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
...@@ -53,6 +53,7 @@ from theano.tensor import opt ...@@ -53,6 +53,7 @@ from theano.tensor import opt
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import Updates
from theano.compile import ops
import scan_op import scan_op
...@@ -849,7 +850,7 @@ def scan(fn, ...@@ -849,7 +850,7 @@ def scan(fn,
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None: if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy' new_var.name = input.variable.name + '_copy'
if isinstance(new_var.type, tensor.TensorType): if isinstance(new_var.type, ops.expandable_types):
sit_sot_inner_inputs.append(new_var) sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
scan_utils.expand( scan_utils.expand(
......
...@@ -1076,6 +1076,7 @@ class TensorType(Type): ...@@ -1076,6 +1076,7 @@ class TensorType(Type):
""" """
return numpy.zeros(shape, dtype=self.dtype) return numpy.zeros(shape, dtype=self.dtype)
theano.compile.ops.expandable_types += (TensorType,)
# Register TensorType C code for ViewOp. # Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论