0%

pytorch训练的模型导出数据或者部署到C++平台

pytorch训练的模型导出网格数据或者部署到C++平台

数据导出

  • 导出为.csv文件
  • 参考教程
# 定义和实例化网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1,30,kernel_size=5)
self.fc1 = nn.Linear(4320,100)
self.fc2 = nn.Linear(100, 10)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),2)
x = x.view(-1, 4320)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x,dim=1)
network = Net()

# 保存.csv格式的参数
for name,param in network.named_parameters():
print(f"name:{name}\t\t\t,shape:{param.shape}")
data = pd.DataFrame(param.detach().numpy().reshape(1, -1))
filename = f"{name}.csv"
data.to_csv(f"./{filename}", index=False, header=False, sep=',')

使用Libtorch直接导出C++模型