提交 195517c4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5060 from abergeron/fix_advsub_dtypes

GpuAdvancedIncSubtensor1 supports mixed dtypes.
...@@ -994,19 +994,12 @@ def local_gpua_advanced_incsubtensor(op, context_name, inputs, outputs): ...@@ -994,19 +994,12 @@ def local_gpua_advanced_incsubtensor(op, context_name, inputs, outputs):
x, y, ilist = inputs x, y, ilist = inputs
# Gpu Ops needs both inputs to have the same dtype
if (x.type.dtype != y.type.dtype):
dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.dtype != dtype:
x = tensor.cast(x, dtype)
if y.type.dtype != dtype:
y = tensor.cast(y, dtype)
set_instead_of_inc = op.set_instead_of_inc set_instead_of_inc = op.set_instead_of_inc
compute_capability = int(context.bin_id[-2]) compute_capability = int(context.bin_id[-2])
if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2): if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2 or
x.type.dtype != y.type.dtype):
return GpuAdvancedIncSubtensor1( return GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
else: else:
......
...@@ -599,7 +599,6 @@ class GpuAdvancedIncSubtensor1(Op): ...@@ -599,7 +599,6 @@ class GpuAdvancedIncSubtensor1(Op):
y_ = as_gpuarray_variable(y, ctx_name) y_ = as_gpuarray_variable(y, ctx_name)
ilist_ = tensor.as_tensor_variable(ilist) ilist_ = tensor.as_tensor_variable(ilist)
assert x_.type.dtype == y_.type.dtype
assert x_.type.ndim >= y_.type.ndim assert x_.type.ndim >= y_.type.ndim
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论