提交 9bb6b8d6 authored 作者: James Bergstra's avatar James Bergstra

fix for gpu subtensor code

上级 0c5dc59f
...@@ -1556,6 +1556,10 @@ class TensorFromScalar(Op): ...@@ -1556,6 +1556,10 @@ class TensorFromScalar(Op):
tensor_from_scalar = TensorFromScalar() tensor_from_scalar = TensorFromScalar()
class ScalarFromTensor(Op): class ScalarFromTensor(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, t): def make_node(self, t):
assert isinstance(t.type, TensorType) assert isinstance(t.type, TensorType)
assert t.type.broadcastable == () assert t.type.broadcastable == ()
...@@ -3086,6 +3090,8 @@ class Subtensor(Op): ...@@ -3086,6 +3090,8 @@ class Subtensor(Op):
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list): def helper_c_code(node, name, inputs, outputs, sub, idx_list):
if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError()
# #
# two arrays are created: # two arrays are created:
# is_slice: len == ndim, 0 means int, 1 means slice # is_slice: len == ndim, 0 means int, 1 means slice
...@@ -3295,6 +3301,8 @@ class Subtensor(Op): ...@@ -3295,6 +3301,8 @@ class Subtensor(Op):
@staticmethod @staticmethod
def helper_c_code_cache_version(): def helper_c_code_cache_version():
if not isinstance(node.inputs[0].type, TensorType):
return ()
return (2,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): #DEBUG def c_code(self, node, name, inputs, outputs, sub): #DEBUG
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论