当前位置:网站首页>Pytorch calls cublasLtMatmul to do GEMM and add bias. It's well written
Pytorch calls cublasLtMatmul to do GEMM and add bias. It's well written
2022-07-19 09:36:00 【Eloudy】
// Computes result = epilogue(alpha * op(mat1) * op(mat2) + bias) in a single
// cuBLASLt call, fusing the bias add (and an optional ReLU/GELU activation)
// into the matmul epilogue instead of running a separate elementwise kernel.
//
// Dtype:            double, float, at::Half, or at::BFloat16 (see the type
//                   dispatch below; any other type keeps the FP32 defaults).
// transpose_mat1/2: whether op() transposes each input matrix.
// m, n, k:          GEMM dimensions. Note cuBLAS is column-major; callers are
//                   expected to pass dimensions/leading strides accordingly.
// alpha_val:        scale on the product, given in the opmath (accumulation)
//                   type of Dtype.
// mat1_ptr/mat2_ptr/result_ptr + *_ld: device pointers and leading dimensions.
// bias:             device pointer to the bias vector added in the epilogue
//                   (presumably length m, broadcast across columns, per the
//                   cuBLASLt BIAS epilogue — confirm against callers).
// activation:       NONE / RELU / GELU, mapped to a fused epilogue below.
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<Dtype> alpha_val,
const Dtype* mat1_ptr,
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation) {
using opmath_t = at::opmath_type<Dtype>;
// beta is forced to 0: the bias is injected via the epilogue attribute, not
// via the classic C = alpha*A*B + beta*C accumulation path.
opmath_t beta_val = 0; // bias is added in epilogue
// Defaults are FP32 everywhere; specialized per-Dtype below.
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
if (std::is_same<Dtype, double>::value) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if (std::is_same<Dtype, float>::value) {
// Opt into TF32 tensor-core math for FP32 only when the global PyTorch
// context allows it; inputs/outputs stay FP32 either way.
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
abcType = CUDA_R_32F;
} else if (std::is_same<Dtype, at::Half>::value) {
// Half/BF16 storage with FP32 compute and FP32 alpha/beta (scaleType).
abcType = CUDA_R_16F;
} else if (std::is_same<Dtype, at::BFloat16>::value) {
abcType = CUDA_R_16BF;
}
// RAII wrapper (declared elsewhere in this file) around
// cublasLtMatmulDescCreate/Destroy.
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSA,
&transa,
sizeof(transa)));
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_TRANSB,
&transb,
sizeof(transb)));
// Select the fused epilogue. NOTE(review): on CUDA < 11.4 a requested GELU
// silently degrades to plain bias (no activation) because the
// GELU_BIAS enum is compiled out — callers should be aware of this.
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
// The attribute value is the device pointer itself, so we pass the address
// of the local pointer variable with size sizeof(Dtype*).
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias,
sizeof(Dtype*)));
// Column-major layouts: a transposed operand swaps its row/col extents,
// while the leading dimension stays the caller-provided physical stride.
CuBlasLtMatrixLayout Adesc(
abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
CuBlasLtMatrixLayout Bdesc(
abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = 1024 * 1024;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
preference.descriptor(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize,
sizeof(workspaceSize)));
// Workspace is a throwaway byte tensor on the current CUDA device; going
// through the ATen allocator ties its lifetime/stream semantics to the
// caching allocator rather than raw cudaMalloc.
auto workspace = at::empty(
{static_cast<int64_t>(workspaceSize)},
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
// A cublasHandle_t is reinterpreted as a cublasLtHandle_t — presumably
// relying on the documented cuBLASLt interop; verify against the cuBLASLt
// docs for the CUDA versions PyTorch supports.
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
// Ask the heuristic for a single algorithm; D and C share Cdesc since the
// matmul writes in place over the same buffer/layout.
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
// Zero returned algorithms means no kernel supports this problem/config.
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
// Launch on the current PyTorch CUDA stream with the chosen algorithm.
// result_ptr appears twice: once as C (beta input, unused since beta == 0)
// and once as D (output).
TORCH_CUDABLAS_CHECK(cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace.data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream()));
}
边栏推荐
- [fishing artifact] UI library second low code tool - form part (II) sub control
- C语言力扣第25题之k个一组反转链表。多指针遍历
- Fundamentals of C language -- 2-1 pointer and wild pointer
- MySQL -- SQL optimization case -- implicit character encoding conversion
- MySQL 用户管理
- how to use cublasLt
- Componentized advanced -- slot
- DEDECMS织梦文章列表标题重复显示解决方案
- 【网络研究院】机器学习系统的威胁是时候该认真对待了
- Day 5 training
猜你喜欢

565. 数组嵌套

MySQL user management

Anycontrol demo demo demo

Add - before the command in makefile to ignore the error caused by the command and continue to execute the next command

银河麒麟v10-arm版离线安装Portainer

LDA classifier

18、shell脚本编程(1)

ETH的拐点可能指日可待,这就是如何

Flink small knowledge -- configuration of task scheduling slots slotsharinggroup

Day 6 training
随机推荐
Mutual access between components
【C语言】数组知识点总结
Static routing!! Static routing!! Static routing!!
Go exceed API source code reading (II) -- openFile ()
如何正确执行Jedis单元测试
组件间的相互访问
开发第一个Flink应用
v-mode
[performance optimization methodology series] VI. summary
Anaconda and jupyter notebook entry level detailed tutorial
MySQL升级为主备,如何同步历史数据?
目标检测模型大小计算,模型复杂度(参数换算公式)
Detailed explanation of the different usage of PIP and PIP3
2022.7.16-----leetcode. Sword finger offer 041
Part I - Fundamentals of C language_ 4. Procedure flow structure
ArrayList底层分析
Series operation of vector container (detailed explanation)
EBay searches eBay products by keyword API return value description
Chapter VIII vector of STL
The inflection point of eth may be just around the corner, which is how to