提交 5f470abb authored 作者: Razvan Pascanu's avatar Razvan Pascanu

list of expandable types used by scan for shared variables

上级 56f011a4
......@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op):
deep_copy_op = DeepCopyOp()
# List of Theano Types that one can add an extra dimension and for which
# Scan can deal with.
expandable_types = ()
......@@ -411,6 +411,7 @@ class CudaNdarrayType(Type):
def c_compile_args(self):
return []
theano.compile.ops.expandable_types += (CudaNdarrayType,)
# Register C code for ViewOp on CudaNdarrayType
theano.compile.register_view_op_c_code(
......
......@@ -53,6 +53,7 @@ from theano.tensor import opt
from theano import tensor
from theano import config
from theano.updates import Updates
from theano.compile import ops
import scan_op
......@@ -849,7 +850,7 @@ def scan(fn,
new_var = safe_new(input.variable)
if getattr(input.variable, 'name', None) is not None:
new_var.name = input.variable.name + '_copy'
if isinstance(new_var.type, tensor.TensorType):
if isinstance(new_var.type, ops.expandable_types):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
scan_utils.expand(
......
......@@ -1076,6 +1076,7 @@ class TensorType(Type):
"""
return numpy.zeros(shape, dtype=self.dtype)
theano.compile.ops.expandable_types += (TensorType,)
# Register TensorType C code for ViewOp.
theano.compile.register_view_op_c_code(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论