【昇腾CANN训练营】aclnn调用中的C++模板编程
1 C++模板简介
在编程世界中,C++模板就像是一个神奇的模具,能够根据需要塑造出不同形态的函数和类。这种机制的精妙之处在于,它允许开发者用一套代码应对各种数据类型,而不必为整数、浮点数或是自定义类型分别编写重复的逻辑。想象一下,如果每次处理新数据类型都要重写一遍相似的代码,那将是多么低效的事情。正是为了避免这种"重复造轮子"的情况,模板应运而生。
让我们从函数模板这个最基础的概念说起。通过引入类型参数T,函数模板实现了所谓的泛型编程——这里的T就像是一个万能容器,可以容纳任何数据类型。当编译器遇到模板代码时,会根据实际调用的数据类型,自动生成对应的特化版本。这个过程发生在编译阶段,既保证了类型安全,又不会带来运行时开销。
举个简单的例子,假设我们需要一个比较两个值大小的函数。如果没有模板,就不得不为int、float、double等类型分别编写几乎相同的代码。而使用模板后,只需定义一次,就能适用于所有可比较的类型。这种优雅的解决方案背后,是C++强大的编译时多态机制在发挥作用。
简而言之,模板用于创建通用函数/类,支持多种数据类型。以下是一个简单的函数模板的代码示例,函数模板通过类型参数 T 实现泛型编程。
#include <iostream>
using namespace std;
// 声明函数模板
template <typename T>
T add(T a, T b) {
return a + b;
}
int main() {
cout << add<int>(3, 5) << endl; // 输出8
cout << add<double>(2.5, 3.7) << endl; // 输出6.2
return 0;
}
2 aclnn调用中的模版编程
我们以 matmul_all_reduce 算子为例,来看看在aclnn调用中的模版编程。
GetInputBuffer 模板函数
GetInputBuffer 模板函数位于 class OpRunner 中,当调用者请求某个输入缓冲区时,这个方法会先检查索引是否越界——这是基本的防御性编程。值得注意的是 reinterpret_cast<T*> 这个操作,它就像是一个万能钥匙,能够将原始的字节流按照模板参数T指定的类型进行解读。这种设计既保持了接口的统一性,又不会牺牲类型安全,正是模板编程的精妙所在。
template<typename T>
auto GetInputBuffer(size_t index) -> T*
{
if (index >= numInputs_) {
ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
return nullptr;
}
return reinterpret_cast<T *>(hostInputs_[index]);
}
多线程调度部分
在主函数部分,std::vector<std::unique_ptr<std::thread>> 这个结构生成了一个线程容器,每个线程都被封装在智能指针中,这种层层嵌套的模板使用方式,既保证了线程对象的生命周期管理,又避免了内存泄漏的风险。循环中创建线程的方式,则是通过 threads[rankId].reset 这种操作,为每个线程分配明确的任务——执行 RunOp 函数,并传递特定的 rankId 和通信对象。
int main(int argc, char **argv)
{
if (!InitResource()) {
ERROR_LOG("Init resource failed");
return FAILED;
}
INFO_LOG("Init resource success");
HcclComm comms[RANK_DIM];
int32_t devices[RANK_DIM];
for (int32_t i = 0; i < RANK_DIM; i++) {
devices[i] = i;
}
if (HcclCommInitAll(RANK_DIM, devices, comms) != HCCL_SUCCESS) {
ERROR_LOG("Hccl comm init failed.");
(void)aclFinalize();
return FAILED;
}
// run with multithread
std::vector<std::unique_ptr<std::thread>> threads(RANK_DIM);
for (uint32_t rankId = 0; rankId < RANK_DIM; rankId++) {
threads[rankId].reset(new(std::nothrow) std::thread(&RunOp, rankId, std::ref(comms[rankId])));
}
for (uint32_t rankId = 0; rankId < RANK_DIM; rankId++) {
threads[rankId]->join();
}
(void)aclFinalize();
return SUCCESS;
}
这就是aclnn调用中的C++模板编程部分,它体现了模板的核心价值:通过编译时多态实现代码复用,同时严格维护类型系统的完整性。而且不同于运行时多态需要付出虚函数调用的开销,模板的所有解析都在编译阶段完成,既高效又安全。
- 点赞
- 收藏
- 关注作者
评论(0)