Stateless Models in PyTorch
Deep learning libraries handle the weights or parameters of a neural network in the form of a “state”. The way this state is defined is often referred to as the state representation of a model, and it can become crucial in certain problems. Broadly speaking, there are two widely used state representation methods: stateful and stateless.
In this article, I’ll briefly introduce the concept and use examples to show when each method is the better choice for a PyTorch implementation. Some topics related to this article are:
- Meta-learning
- Hypernetworks
- Ensemble Modelling
What are Stateful and Stateless Models?
Getting straight to the point, in PyTorch, the two state representations are defined as follows:
Stateful Models are models whose state (aka the model’s weights) is defined and stored inside the model class. When you create an instance of the model, it initializes the weights internally. By default, PyTorch models are stateful. Below is an example of a stateful MLP with one hidden layer:
```python
import torch
import torch.nn as nn

class StatefulModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(StatefulModel, self).__init__()
        # The layers own their weights: this is the model's internal state
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize model
model = StatefulModel(10, 50, 1)
# Initialize a random input
x = torch.rand(1, 10)
# Make prediction
pred = model(x)
```
In the example, the layers are defined as attributes of the class, specifically submodules of type nn.Linear. Each layer definition contains the operations needed for forward propagation and also maintains the state (weight and bias) of that layer. As a result, the model holds its complete state, which can be accessed via model.state_dict(). In PyTorch, this method returns the weights as an OrderedDict that maps parameter names such as fc1.weight to tensors.
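To see what this state looks like, the sketch below re-declares the same two-layer model (so it runs on its own) and lists the entries of its state dict. The names come from the attribute names fc1 and fc2, and the shapes follow from nn.Linear storing its weight as (out_features, in_features); the snippet is an illustration, not something the approach requires.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

class StatefulModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = StatefulModel(10, 50, 1)
state = model.state_dict()

print(isinstance(state, OrderedDict))  # True
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# fc1.weight (50, 10)
# fc1.bias (50,)
# fc2.weight (1, 50)
# fc2.bias (1,)
```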
Stateless Models, on the other hand, do not store weights internally. Instead, the weights are passed as input. The example below shows a stateless definition of a model in PyTorch:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StatelessModel(nn.Module):
    def forward(self, x, weights):
        # All weights are read from the `weights` argument; nothing is stored on self
        x = F.relu(F.linear(x, weights['fc1.weight'], weights['fc1.bias']))
        x = F.linear(x, weights['fc2.weight'], weights['fc2.bias'])
        return x

# Initialize model
model = StatelessModel()
# Initialize weights
weights = {
    'fc1.weight': torch.randn(50, 10),
    'fc1.bias': torch.randn(50),
    'fc2.weight': torch.randn(1, 50),
    'fc2.bias': torch.randn(1)
}
# Initialize random input
x = torch.rand(1, 10)
# Make prediction
pred = model(x, weights)
```
As you can see, the stateless version initializes no submodules when an instance of the class is created. Instead, the class only defines the forward function, which specifies how an input is mapped to an output through a set of tensor operations applied with the weights it receives.
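One way to check that the two definitions describe the same network is to feed the stateful model's own state dict into the stateless one. The sketch below re-declares both classes so that it runs on its own; the seed and the batch size of 4 are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StatefulModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

class StatelessModel(nn.Module):
    def forward(self, x, weights):
        x = F.relu(F.linear(x, weights['fc1.weight'], weights['fc1.bias']))
        return F.linear(x, weights['fc2.weight'], weights['fc2.bias'])

torch.manual_seed(0)
stateful = StatefulModel(10, 50, 1)
stateless = StatelessModel()

# The stateless model owns no parameters at all
print(len(list(stateless.parameters())))  # 0

# Reuse the stateful model's state dict as the external weights
weights = stateful.state_dict()
x = torch.rand(4, 10)
print(torch.allclose(stateful(x), stateless(x, weights)))  # True
```

Because the dictionary keys match the stateful model's parameter names, the same weights can move freely between the two representations.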
Ambiguity
In the context of Recurrent Neural Networks (RNNs), the terms “stateful” and “stateless” can take on a different meaning and refer to how the hidden state of the RNN is initialized. In a stateful RNN, the final hidden state from one batch is carried over as the initial hidden state of the next. In contrast, a stateless RNN re-initializes the hidden state (typically to zeros) for every new batch.
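In PyTorch terms, the difference comes down to what is passed as the initial hidden state. The sketch below uses nn.RNN, which starts from zeros when no hidden state is given; the sizes and the three random batches are arbitrary. The detach() call in the stateful loop is a common choice (truncated backpropagation through time) that keeps the carried-over values while stopping gradients at batch boundaries.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)
# Three consecutive batches, each with 2 sequences of length 4
batches = [torch.randn(2, 4, 3) for _ in range(3)]

# "Stateless" RNN: every batch starts from a fresh (zero) hidden state
stateless_outputs = []
for batch in batches:
    out, _ = rnn(batch)  # h0 defaults to zeros
    stateless_outputs.append(out)

# "Stateful" RNN: the final hidden state of one batch initializes the next
stateful_outputs = []
h = None  # None also means zeros for the very first batch
for batch in batches:
    out, h = rnn(batch, h)
    h = h.detach()  # keep the values, stop gradients across batches
    stateful_outputs.append(out)
```

Both variants agree on the first batch and diverge from the second batch onward, once the stateful version starts from a non-zero hidden state.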
When to Use Which?
TLDR
“Stateless” models can be used for any type of problem, but they require manual management of the state and often involve more implementation effort. “Stateful” models are easier to implement, but can make the life of a researcher or programmer harder in certain use cases.
Some libraries, such as Flax (built on JAX), assume a stateless model definition by default, while standard model definitions in PyTorch are stateful. In general, if your problem follows the “standard” training loop below:
REPEAT:
- get data batch
- forward an input to the model
- compute loss and backpropagate
- update weights
then you’re probably better off choosing the stateful option. In most instances, stateful versions are more convenient because the model and its weights are handled together as an all-in-one package.
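For reference, the loop above might look like the following minimal sketch with a stateful model. The toy data, the nn.Sequential architecture (same layer sizes as StatefulModel(10, 50, 1)), the SGD optimizer and the hyperparameters are all illustrative assumptions. Note that the weights never leave the model: optimizer.step() updates them in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same layer sizes as StatefulModel(10, 50, 1), written as nn.Sequential for brevity
model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Hypothetical toy regression data standing in for a real data loader
inputs = torch.randn(256, 10)
targets = inputs.sum(dim=1, keepdim=True)

with torch.no_grad():
    initial_loss = loss_fn(model(inputs), targets).item()

for step in range(200):
    idx = torch.randint(0, inputs.shape[0], (32,))  # get data batch
    pred = model(inputs[idx])                       # forward an input to the model
    loss = loss_fn(pred, targets[idx])              # compute loss
    optimizer.zero_grad()
    loss.backward()                                 # backpropagate
    optimizer.step()                                # update weights stored inside the model

with torch.no_grad():
    final_loss = loss_fn(model(inputs), targets).item()
print(f"loss on all data: {initial_loss:.3f} -> {final_loss:.3f}")
```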
Examples
Despite their user-friendliness, stateful models can be problematic in certain situations. I’ve listed some of these situations, with an illustration for each, below:
When the weights of a model are generated by another model, or when they are perturbed or changed by an external modifier.
A popular example is the hypernetwork. Hypernetworks generate the weights of another network, usually referred to as the “main model”. The main model does not maintain its own state; instead, it receives its weights as input, produced by the hypernetwork.
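A minimal hypernetwork sketch, assuming a single linear layer as the hypernetwork and a task embedding z as its input. HyperNetwork, main_model, SHAPES and all dimensions are hypothetical names and values chosen for illustration. The main model is a plain function of (x, weights), so it has no state of its own, and gradients of its prediction flow back into the hypernetwork's parameters.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

IN, HID, OUT = 10, 50, 1
# Parameter shapes of the main model (same layout as the stateless example)
SHAPES = {
    'fc1.weight': (HID, IN),
    'fc1.bias': (HID,),
    'fc2.weight': (OUT, HID),
    'fc2.bias': (OUT,),
}

def main_model(x, weights):
    # Stateless main model: all weights arrive as an argument
    x = F.relu(F.linear(x, weights['fc1.weight'], weights['fc1.bias']))
    return F.linear(x, weights['fc2.weight'], weights['fc2.bias'])

class HyperNetwork(nn.Module):
    """Hypothetical hypernetwork: maps a task embedding to all main-model weights."""

    def __init__(self, embed_dim):
        super().__init__()
        n_params = sum(math.prod(shape) for shape in SHAPES.values())
        self.net = nn.Linear(embed_dim, n_params)

    def forward(self, z):
        flat = self.net(z)  # one flat vector holding every main-model parameter
        weights, offset = {}, 0
        for name, shape in SHAPES.items():
            n = math.prod(shape)
            weights[name] = flat[offset:offset + n].reshape(shape)
            offset += n
        return weights

torch.manual_seed(0)
hyper = HyperNetwork(embed_dim=8)
z = torch.randn(8)              # hypothetical task embedding
x = torch.rand(4, IN)
pred = main_model(x, hyper(z))  # the main model receives its weights as input
pred.sum().backward()           # gradients flow back into the hypernetwork
print(pred.shape)               # torch.Size([4, 1])
```

Only the hypernetwork has trainable parameters here; an optimizer would be built over hyper.parameters().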
When the graph of the optimization process has to be maintained.
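Usually, in the inner loop of meta-optimization methods (for example, MAML-style meta-learning), we need to keep the computation graph throughout the entire optimization process, so that the outer loss can be differentiated through the inner updates. To achieve this, the updated weights cannot simply be written back into the model in place; the state has to be kept externally so that every intermediate set of weights remains part of the graph.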
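The sketch below shows a MAML-style inner loop on a single hypothetical task, assuming PyTorch 2.x where torch.func.functional_call is available. The adapted weights live in an external dictionary instead of being written back into the model, and create_graph=True keeps each inner update differentiable. The data, the number of inner steps and inner_lr are illustrative.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
loss_fn = nn.MSELoss()
inner_lr = 0.1  # hypothetical inner-loop learning rate

# Hypothetical support/query sets for a single task
x_s, y_s = torch.randn(16, 10), torch.randn(16, 1)
x_q, y_q = torch.randn(16, 10), torch.randn(16, 1)

# Start from the model's own parameters; the dict is the external state
params = dict(model.named_parameters())
for _ in range(3):
    inner_loss = loss_fn(functional_call(model, params, (x_s,)), y_s)
    # create_graph=True keeps this update in the graph for the outer loss
    grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
    params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# Outer (meta) loss evaluated with the adapted weights
outer_loss = loss_fn(functional_call(model, params, (x_q,)), y_q)
outer_loss.backward()  # reaches the original parameters through all inner steps
```

Replacing the dictionary update with optimizer.step() would modify the parameters in place and break the chain of inner updates, which is what the stateless formulation avoids.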
Parallel inference in ensemble models.
Although it’s possible to perform inference in ensemble models with the stateful version, keeping multiple copies of the same model for parallel inference can become costly. For example, if we need to make predictions with hundreds of models, each a small perturbation of a reference model, stateless versions can do this more efficiently: one model definition is evaluated against many sets of weights.
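One way to do this in PyTorch 2.x is to combine torch.func.functional_call with torch.func.vmap: the perturbed weights of all ensemble members are stacked along a leading dimension, and vmap evaluates the single reference module once per member. The ensemble size, the noise scale sigma and the Gaussian perturbation are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)
# A single reference model; its architecture is shared by every ensemble member
reference = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
n_members, sigma = 100, 0.01  # hypothetical ensemble size and perturbation scale

# Stack n_members perturbed copies of each parameter along a new leading
# dimension, without instantiating n_members modules
with torch.no_grad():
    stacked = {
        name: p.unsqueeze(0) + sigma * torch.randn(n_members, *p.shape)
        for name, p in reference.named_parameters()
    }

def predict(params, x):
    # Stateless call of the reference architecture with one member's weights
    return functional_call(reference, params, (x,))

x = torch.rand(32, 10)
# in_dims=(0, None): map over dim 0 of every tensor in `stacked`, share x
preds = vmap(predict, in_dims=(0, None))(stacked, x)
print(preds.shape)  # torch.Size([100, 32, 1])
```

For ensembles of independently trained modules rather than perturbations, torch.func.stack_module_state can build the stacked parameter dictionary instead.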
Stateless Models in PyTorch
Before PyTorch 2.0, converting a stateful model to its stateless equivalent required rewriting the forward function using functionals. An alternative was to use packages such as functorch and higher, which facilitated this conversion through monkeypatching. Luckily, since version 2.0, there is no need for external modules. A stateful model can be called in a stateless manner using torch.nn.utils.stateless.functional_call (exposed in PyTorch 2.x as torch.func.functional_call, which is now the recommended entry point), as demonstrated in the following example.
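```python
import torch
from torch.nn.utils.stateless import functional_call
# (In PyTorch 2.x, torch.func.functional_call is the recommended drop-in equivalent)

# Using the StatefulModel definition in the first code example
# Initialize model
model = StatefulModel(10, 50, 1)
# Initialize random input
x = torch.rand(1, 10)
# Get the state dict of the model
state_dict = model.state_dict()
# Stateless call: the forward pass uses the tensors in state_dict
pred = functional_call(model, state_dict, (x,))
# Equivalent stateful call with the model's own weights, for comparison
pred_stateful = model(x)
```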
In the example, we don’t convert the model directly. Instead, we take the stateful version and call it in a stateless manner. This minor change in the function call enables cleaner implementations of methods that previously relied on third-party libraries in more intricate ways.
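To illustrate that functional_call only borrows the module's structure, the sketch below (assuming PyTorch 2.x and torch.func.functional_call) runs a model with an externally supplied set of all-zero weights. The zero weights are purely illustrative; the point is that the output is computed from the supplied tensors while the module's own parameters are unchanged after the call.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1))
x = torch.rand(1, 10)

# Illustrative external weights: every parameter replaced by zeros
zero_weights = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

before = [p.detach().clone() for p in model.parameters()]
out = functional_call(model, zero_weights, (x,))

# The forward pass used the supplied zeros, so the output is zero
print(torch.count_nonzero(out).item())  # 0
# The module's own parameters are untouched after the call
print(all(torch.equal(b, p) for b, p in zip(before, model.parameters())))  # True
```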