Learning cos(x) using torch::Sequential from the PyTorch C++ API
This post covers the use of the PyTorch C++ API for approximating a function of a single variable with a Neural Network (NN). We'll use the Sequential container to build the NN with very little C++ code and train it on $(x, cos(x))$ data.
The Sequential container in PyTorch
The Sequential container chains a sequence of PyTorch modules, i.e. layers of a Neural Network (NN), into a list. The Sequential container then forwards operations to the entire sequence of modules without additional code: for example, a call to to() is applied to every module in Sequential, and a call to forward() is chained automatically through all modules. This way, Sequential makes it possible to combine network layers and activation functions into NNs without writing structs or classes in C++.
Unnamed modules are chained inside the Sequential container like this:
torch::nn::Sequential model(
    torch::nn::Linear(1, 64),
    torch::nn::Tanh(),
    torch::nn::Linear(64, 1)
);
This creates a NN with a single scalar input, two linear layers with 64 hidden neurons and a tanh activation function between them, and an output layer with a single scalar output.
Named sub-modules can be passed to the Sequential model like this
torch::nn::Sequential model({
    {"in", torch::nn::Linear(1, 64)},
    {"tanh1", torch::nn::Tanh()},
    {"out", torch::nn::Linear(64, 1)}
});
or added using Sequential::push_back, as sketched below. Multiple Sequential models can also be joined together into larger networks.
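For illustration, here is a minimal sketch (assumed, not taken from the original application) that builds the same network incrementally with push_back, using the named overload:
// Build the same network step by step with Sequential::push_back
// (the unnamed overload push_back(module) also exists).
torch::nn::Sequential model;
model->push_back("in", torch::nn::Linear(1, 64));
model->push_back("tanh1", torch::nn::Tanh());
model->push_back("out", torch::nn::Linear(64, 1));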
The torch::nn::Sequential type is a module holder that wraps a std::shared_ptr to the underlying implementation, so it can be dereferenced for convenience
auto& model_ref = *model;
Then, Sequential member functions can be called with the . access operator
model_ref.pretty_print(cout);
which also makes it easier to loop over sub-modules using range-based for loops.
Note: the torch::nn::Sequential container stores its sub-modules as the torch::nn::AnyModule type:
The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where the AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.
To access the underlying module, we use the torch::nn::AnyModule::ptr() member function; then we can inspect the names and parameters of the sub-modules of Sequential
auto& model_ref = *model;
for (auto module : model_ref)
{
    cout << module.ptr()->name() << endl;
    cout << module.ptr()->parameters() << endl;
}
Training a PyTorch Sequential model on $cos(x)$
We will train the model on the $cos(x)$ function. To do this, we use the periodicity of $cos(x)$: if $f(x + T) = f(x)$, then $f$ is a periodic function with period $T$. This means we can train the model on the interval $[0, T]$ and then wrap the trained model so that model_periodic(x + T) returns model(x). Otherwise, we could try to train the model on $cos(x)$ over many periods, but this would not lead to much success.
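As an illustration, such a periodic wrapper could map the input back into one period with torch::fmod before evaluating the model. The helper model_periodic below is a hypothetical sketch (not part of the application code) and assumes non-negative inputs:
// Hypothetical wrapper: map x into [0, T) and evaluate the trained model there.
torch::Tensor model_periodic(torch::nn::Sequential& model,
                             const torch::Tensor& x,
                             double T)
{
    return model->forward(torch::fmod(x, T));
}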
Generating input data
We take a sequence of $1000$ samples from $[0, 3\pi]$.
// Preparing sample data
static const int N_SAMPLES = 1000;
torch::Tensor x_sequence = torch::linspace(0, 3*M_PI, N_SAMPLES);
// Reshape and save x_sequence
x_sequence = x_sequence.reshape({N_SAMPLES,1});
torch::save(x_sequence, "x_sequence.pt");
// Compute cos(x) and save y_sequence
torch::Tensor y_sequence = torch::cos(x_sequence);
torch::save(y_sequence, "y_sequence.pt");
The reason for using $3\pi$ is explained in the Results section, together with the other NN hyperparameters.
The 1D $x$ and $cos(x)$ sample sequences are reshaped to the shape $(N_{samples}, 1)$ so that they match the first and last modules of the model, which expect a single (scalar) input and output feature, respectively.
Next, we create the training sequence as a subset of the input sequence
// SAMPLE DATA: x, cos(x)
// Training / validation split: 70 / 30
torch::Tensor shuffled_indices = torch::randperm(
    N_SAMPLES,
    torch::TensorOptions().dtype(at::kLong)
);
auto n_val = int(0.7 * N_SAMPLES);
torch::Tensor training_indices =
    shuffled_indices.index({Slice(0, n_val)});
torch::Tensor x_training = x_sequence.index({training_indices});
torch::Tensor y_training = y_sequence.index({training_indices});
Note that we need an integer type for the indices into the $x$ and $cos(x)$ sequences, and the default dtype of a tensor created in PyTorch is float, so we pass the type argument at::kLong to dtype to make sure shuffled_indices holds integers.
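As a side note, a ScalarType converts implicitly to TensorOptions (to my knowledge), so the randperm call above can be written more compactly:
// Equivalent, shorter form: pass the dtype directly.
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::kLong);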
Training the model
We select the optimizer that we'll use to minimize the error and set its learning rate. We also need tensors to store the model prediction and the loss values
// TRAIN THE MODEL ON THE TRAINING SET
torch::optim::Adam optimizer(model->parameters(), 0.01);
torch::Tensor training_prediction = torch::zeros_like(x_training);
torch::Tensor loss_values = torch::zeros_like(x_training);
Here is the training loop:
ofstream conv_file("convergence_data.csv");
conv_file << "max_loss\n";
for (size_t epoch = 1; epoch <= 1000; ++epoch)
{
    optimizer.zero_grad();
    training_prediction = model->forward(x_training);
    loss_values = torch::mse_loss(training_prediction, y_training);
    loss_values.backward();
    optimizer.step();
    // Report the error with respect to y_training.
    double max_loss = loss_values.max().item<double>();
    cout << "Epoch " << epoch
         << ", max(loss_values) = " << max_loss << endl;
    conv_file << max_loss << "\n";
}
PyTorch accumulates the weight gradients of the network over subsequent backward propagations, so optimizer.zero_grad() is called to zero the gradients and ensure that previous passes do not influence the current gradient direction. From the official documentation:
torch.Tensor is the central class of PyTorch. When you create a tensor, if you set its attribute .requires_grad as True, the package tracks all operations on it. This happens on subsequent backward passes. The gradient for this tensor will be accumulated into .grad attribute. The accumulation (or sum) of all the gradients is calculated when .backward() is called on the loss tensor.
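As a small standalone illustration (assumed, not from the application code), two backward passes without zeroing the gradient in between add up in the .grad attribute:
// Standalone sketch: gradients accumulate across backward passes.
torch::Tensor w = torch::ones({1}, torch::requires_grad());
(w * 2).sum().backward();   // w.grad() is now 2
(w * 3).sum().backward();   // w.grad() is now 2 + 3 = 5 (accumulated)
std::cout << w.grad() << std::endl;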
Validating the model
For the validation, we use the last 30% of the randomly shuffled indices
// VALIDATE THE MODEL WITH THE VALIDATION SET
torch::Tensor validation_indices =
    shuffled_indices.index({Slice(n_val, N_SAMPLES)});
torch::Tensor x_validation =
    x_sequence.index({validation_indices});
torch::Tensor y_validation =
    y_sequence.index({validation_indices});
torch::Tensor validation_values =
    model->forward(x_validation);
torch::Tensor validation_loss =
    torch::mse_loss(validation_values, y_validation);
cout << "Validation max(validation_loss) = "
     << torch::max(validation_loss) << endl;
// REPORT THE PREDICTION OVER THE COMPLETE INPUT
torch::Tensor y_model_sequence = model->forward(x_sequence);
torch::save(y_model_sequence, "y_model_sequence.pt");
Results
In the preparation of the training set we didn't use the period $T=2\pi$ of $cos(x)$; instead we used $3\pi$. The reason is the asymmetry I found in the approximation of $cos(x)$ with this NN over $[0,2\pi]$: the NN seems not to capture the non-linearity of $cos(x)$ at the right end of the interval $[0,2\pi]$. The training and validation $L_\infty$ norms of the MSE loss are
Epoch 1000, max(loss_values) = 0.000760586
Validation max(validation_loss) = 0.000676395
So it seems the NN does not overfit, and since the model lacks accuracy near $2\pi$, one could think that it simply needs more neurons (an increase in nonlinearity). I tried adding more neurons and more hidden layers, and changing the loss function (to smooth $L_1$) and the activation functions: this changes the results, but the asymmetry in the approximation remains.
Interestingly, if the same model is trained on $[0,4\pi]$, the losses become
Epoch 1000, max(loss_values) = 0.00219363
Validation max(validation_loss) = 0.0024011
The convergence behavior of the maximal MSE loss over the epochs can be plotted from convergence_data.csv.
The validation error is larger, because the asymmetry this time appears at $4\pi$, but the same model seems to capture the nonlinearity on $[0,2\pi]$. Since the point of this post was to document torch::Sequential, I'll leave out further hyperparameter tuning of the NN for this example.
Thanks to Andre Weiner for the help with hyperparameters and the model IO!
Data & Code
The application is in applications/test/aiFoamTestPyTorch. You can compile it with wmake if you have OpenFOAM installed and run it anywhere, or with cmake
?> mkdir build && cd build && cmake .. && make
and then run ./aiFoamTestPyTorch in the build directory.
When compiling it within OpenFOAM, make sure OpenFOAM is built with support for C++14 by setting -std=c++14 (or newer, e.g. -std=c++2a) in
$WM_PROJECT_DIR/wmake/rules/General/Gcc/c++
if you are using the gcc compiler, or in the corresponding sub-folder of General for the compiler you use to build OpenFOAM.