TensorFlow c++代码案例
时间: 2023-07-09 21:35:22 浏览: 195
以下是一个使用 TensorFlow C++ API 实现简单的线性回归模型的示例代码:
```cpp
#include <iostream>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;
int main() {
// 构建训练数据
Tensor x(DT_FLOAT, TensorShape({3, 1}));
auto x_map = x.tensor<float, 2>();
x_map(0, 0) = 1.0;
x_map(1, 0) = 2.0;
x_map(2, 0) = 3.0;
Tensor y(DT_FLOAT, TensorShape({3, 1}));
auto y_map = y.tensor<float, 2>();
y_map(0, 0) = 2.0;
y_map(1, 0) = 4.0;
y_map(2, 0) = 6.0;
// 定义模型
auto scope = Scope::NewRootScope();
auto X = Placeholder(scope, DT_FLOAT);
auto Y = Placeholder(scope, DT_FLOAT);
auto W = Variable(scope, {1, 1}, DT_FLOAT);
auto b = Variable(scope, {1}, DT_FLOAT);
auto Y_hat = Add(scope, MatMul(scope, X, W), b);
auto loss = Mean(scope, Square(scope, Sub(scope, Y_hat, Y)), {0, 1});
// 构建优化器
auto train_op = GradientDescentOptimizer(scope, 0.01f).Minimize(loss);
// 运行会话
ClientSession session(scope);
TF_CHECK_OK(session.Run({Assign(scope, W, {{0.0f}}), Assign(scope, b, {0.0f})}, nullptr));
for (int i = 0; i < 100; i++) {
TF_CHECK_OK(session.Run({{X, x}, {Y, y}}, {train_op}, nullptr));
}
// 打印结果
auto W_tensor = W.eval(session);
auto b_tensor = b.eval(session);
cout << "W: " << W_tensor.tensor<float, 2>()(0, 0) << endl;
cout << "b: " << b_tensor.tensor<float, 1>()(0) << endl;
return 0;
}
```
这段代码实现了一个简单的线性回归模型,使用 TensorFlow C++ API 定义了模型、损失函数和优化器,并在训练数据上进行了训练。最后输出训练得到的模型参数。
阅读全文