
Multi-layer perceptrons

So far, we have been using Logistic Regression for all our classification needs. Logistic regression is very similar to linear regression, except for the $\sigma(z)$ at the end: it is essentially a linear projection followed by a choice between the "positive" and the "negative" sides of the projection surface.

Also, we have seen that we can choose to project our data $X$ into an intermediate representation $z$ so that $z$ has more than one dimension. We can use that for multi-class classification.

Now, we are going to examine the effect of mapping the intermediate projection $z$ to another intermediate projection (let's call it $z_2$). As we will see, increasing the dimensionality of each representation $z_i$ and the number of intermediate projections creates intermediate regions in which we can apply linear transformations separately. For a theoretical reference, see: Thiago Serra, Christian Tjandraatmadja, Srikumar Ramalingam. "Bounding and Counting Linear Regions of Deep Neural Networks". Proceedings of the 35th International Conference on Machine Learning, PMLR 80:4558-4566, 2018.

In the examples shown here, we will start by using linear regression to show some limitations of this approach. The animated examples work like this:

  1. We define a random dataset $X$
  2. We define a target $y$ by applying some function over $X$
  3. We initialize a prediction model with an identity function (that is, the weights form an identity matrix and the biases are zero, so that initially $\hat{y}_0 = X$)
  4. We train the prediction model to predict $\hat{y} = f(X)$ using gradient descent, and store $\hat{y}_t$ for each iteration $t$
  5. We make an animation of all $\hat{y}_t$ so we can see what happens to our predictions.
In [ ]:
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
torch.manual_seed(20)
In [2]:
def train_model(model, X, y, lr=0.01, epochs=100):
    
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training loop
    outputs = [X.numpy()]
    for epoch in tqdm(range(epochs)):
        # Forward pass
        predictions = model(X)
        outputs.append(predictions.detach().numpy())
        loss = criterion(predictions, y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model, outputs

def animate_training(outputs, y, frame_duration=5, title='Training Animation'):
    import pandas as pd
    import plotly.express as px

    # Convert outputs to a pandas DataFrame for easier plotting
    frames = []
    for i, output in enumerate(outputs):
        df = pd.DataFrame(output, columns=['Feature 1', 'Feature 2'])
        df['Frame'] = i  # Add a frame identifier
        frames.append(df)

    total_frames = len(frames)
    n = max(1, total_frames // 100)  # Keep at most ~100 frames in the animation
    frames = frames[::n]  # Get every nth frame
    
    # Concatenate all frames into a single DataFrame
    animated_df = pd.concat(frames, ignore_index=True)


    fig = px.scatter(

        animated_df,
        width=600,
        height=600,
        x='Feature 1',
        y='Feature 2',
        animation_frame='Frame',
        title=title,
        labels={'Feature 1': 'Feature 1', 'Feature 2': 'Feature 2'}
    )

    # Add a scatterplot of y
    scatter_y = pd.DataFrame(y.numpy(), columns=['Feature 1', 'Feature 2'])
    scatter_y['Frame'] = -1  # Use -1 to indicate the original data
    for _, row in scatter_y.iterrows():
        fig.add_trace(px.scatter(
            pd.DataFrame([row]),
            x='Feature 1',
            y='Feature 2',
            color='Frame',

        ).data[0])

    # Adjust animation speed
    fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = frame_duration  # Set duration in milliseconds
    fig.update_layout(coloraxis_showscale=False)
    return fig

A simple dataset: a rotation + translation

A linear regression is capable of finding the correct rotation and translation of a dataset. This is because rotations and translations can be expressed directly by the linear prediction equation $y = xw^T + b$: the weight matrix $w$ can be constructed from a rotation, and $b$ corresponds to the translation:

In [3]:
# Create a mock dataset: a rotation by 30 degrees plus a translation
X = torch.randn(100, 2) * 5  # 100 samples, 2 features
theta = torch.tensor(30.0 * torch.pi / 180.0)  # Convert degrees to radians
rotation_matrix = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])
y = X @ rotation_matrix + torch.tensor([5,5]) + 0.01 * torch.randn(100, 2)  # y = rotated X + translation + noise


In [4]:
# Initialize the model, loss function, and optimizer
input_size = 2
output_size = 2
linear_model = nn.Linear(
    in_features=2,
    out_features=2,
)
linear_model.weight.data = torch.eye(2)  # Initializing with identity
linear_model.bias.data = torch.zeros(2)  # Initializing bias to zero

model, outputs = train_model(
    model=linear_model,
    X=X,
    y=y,
    lr=0.01,
    epochs=1000
)

fig = animate_training(outputs, y, title='Linear data, linear model')
fig.show() 
100%|██████████| 1000/1000 [00:03<00:00, 257.16it/s]
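As a quick sanity check (a sketch that assumes the cell above has just been run), the learned parameters should approximate the generating rotation and translation. Since nn.Linear computes $xW^T + b$, the learned weight should be close to the transpose of rotation_matrix:

In [ ]:
print(linear_model.weight.data)  # should be close to rotation_matrix.T
print(linear_model.bias.data)    # should be close to tensor([5., 5.])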

A more complicated dataset: linear by parts

Now, let's build a more complicated dataset: $X$ (our input) will have three different clusters, and we will apply a different linear transform to each cluster.

In [5]:
# Create a mock dataset with three clusters, each with its own linear transform
# Cluster 1: scaling by 3
X1 = torch.randn(100, 2) + torch.tensor([3,-3])   # 100 samples, 2 features
y1 = 3*X1 + 0.01 * torch.randn(100, 2)  # y = 3x + noise

# Cluster 2: rotation by 150 degrees + translation
X2 = torch.randn(100, 2)  # 100 samples, 2 features
theta = torch.tensor(150.0 * torch.pi / 180.0)  # Convert degrees to radians
rotation_matrix2 = torch.tensor([
    [torch.cos(theta), -torch.sin(theta)],
    [torch.sin(theta), torch.cos(theta)]
])
y2 = (X2 @ rotation_matrix2) + torch.tensor([-3,3]) + 0.01 * torch.randn(100, 2)  # y = rotated x + translation + noise

# Cluster 3: scaling by -5
X3 = torch.randn(100, 2) + torch.tensor([3,3]) # 100 samples, 2 features
y3 = -5*X3 - 0.01 * torch.randn(100, 2)  # y = -5x + noise


X = torch.cat((X1, X2, X3), dim=0)
y = torch.cat((y1, y2, y3), dim=0)

When we try to approximate this data using a single linear layer, we obviously can't: the data is more complicated than the model - or, in other words, the model is not expressive enough to fit the data. In the animation, we clearly see that the linear layer can only apply the same transform to all points in the input vector space, hence they all "bend" in the same way.

Our model is unable, for example, to reproduce the different cluster variances caused by the different scalings applied when we generated each part of $y$.

In [6]:
# Initialize the model, loss function, and optimizer
input_size = 2
output_size = 2
linear_model = nn.Linear(
    in_features=2,
    out_features=2,
)
linear_model.weight.data = torch.eye(2)  # Initializing with identity
linear_model.bias.data = torch.zeros(2)  # Initializing bias to zero

model, outputs = train_model(
    model=linear_model,
    X=X,
    y=y,
    lr=0.01,
    epochs=500
)

fig = animate_training(outputs, y, title='Linear by parts data, linear model')
fig.show()
100%|██████████| 500/500 [00:00<00:00, 2659.77it/s]

Exercise

Why did we get perfectly fitting points in the first example, but not in the second example? In the first example, the data was generated using the exact same model family as the one we used for prediction. This means that the predictor was able to recover the exact model used for data generation.

However, in the second example, data was generated using a model with more degrees of freedom (or: a more expressive model) than the one we used for prediction. As a consequence, it finds an approximation that minimizes MSE, but cannot find the exact match. This is the same behavior as trying to fit a linear (y=ax+b) model to a polynomial or an exponential curve.

Multi Layer Perceptron (MLP) models

A possible upgrade to the linear model is the MLP model. The MLP model is:

$$ \hat{y} = f(xw_1^T+b_1)w_2^T+b_2, $$ where $f$ is the Rectified Linear Unit (ReLU) function, given by $f(z)=\max(0, z)$ - that is, $f(z)=0$ for $z<0$ and $f(z)=z$ for $z \geq 0$.

We can interpret this equation as two layers of linear projections, separated by a non-linear operation, that is:

$$ \hat{y} = (xw_1^T+b_1) \circ f \circ (xw_2^T+b_2), $$ where $\circ$ denotes function composition, applied left to right (the output of each block is the input of the next).

We can draw it like this:

flowchart LR;
    subgraph MLP;
        direction LR;
        L1[Linear] --> ReLU --> L2[Linear]
    end;
    Input --> MLP --> Output
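To make the equation concrete, here is a minimal sketch (the sizes 3, 5 and 2 are arbitrary example sizes, not the ones used later in this notebook) computing $\hat{y}$ both explicitly and with PyTorch layers - the two results should match:

In [ ]:
import torch
import torch.nn as nn

x = torch.randn(4, 3)    # 4 samples, 3 features
lin1 = nn.Linear(3, 5)   # w_1, b_1
lin2 = nn.Linear(5, 2)   # w_2, b_2
relu = nn.ReLU()

# Explicit computation: y_hat = f(x w_1^T + b_1) w_2^T + b_2
y_explicit = relu(x @ lin1.weight.T + lin1.bias) @ lin2.weight.T + lin2.bias

# Same computation using the layers as functions
y_layers = lin2(relu(lin1(x)))

print(torch.allclose(y_explicit, y_layers))  # True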

Why MLP? A small example

Let's suppose we have two 1-dimensional input samples: $x_1$ and $-x_2$, where $x_1$ and $x_2$ are real and positive. Our network has 1-dimensional outputs as well.

For simplicity, let's assume $b_1=b_2=0$

$w_1$ will be equal to $\begin{bmatrix} 1 \\ -1 \end{bmatrix}$. Stacking the two samples as the rows of $x$, we get $xw_1^T=\begin{bmatrix} x_1 & -x_1 \\ -x_2 & x_2 \end{bmatrix}$.

Now, after applying $f(.)$ to $xw_1^T$, we have:

$$ f(xw_1^T)=\begin{bmatrix} x_1 & 0 \\ 0 & x_2 \end{bmatrix} $$

This is important: after the ReLU, each row (each input sample) activates a different hidden unit - the two samples are routed through independent paths!

Now, note that $w_2$ must be $1 \times 2$ (let's say it is equal to $[c, d]$, so that $w_2^T$ is a column vector), and:

$$ y = \begin{bmatrix} x_1 & 0 \\ 0 & x_2 \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix} = \begin{bmatrix} c x_1 \\ d x_2 \end{bmatrix} $$

Now, importantly: in our model, the first input received a scale of $c$, while the second input received a scale of $-d$ (it entered as $-x_2$ and came out as $d\,x_2$).

Thus, the model operates in two layers. In the first layer, it divides the inputs into groups; in the second layer, it applies a different linear transform for each group.
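A minimal numeric check of this small example (the values $x_1=2$, $x_2=3$, $c=10$ and $d=-7$ are arbitrary choices, used only for illustration):

In [ ]:
import torch

x = torch.tensor([[2.0], [-3.0]])    # two 1-d samples: x_1 = 2 and -x_2 = -3
w1 = torch.tensor([[1.0], [-1.0]])   # shape (2, 1)
w2 = torch.tensor([[10.0, -7.0]])    # shape (1, 2): [c, d]

h = torch.relu(x @ w1.T)             # each sample activates a different hidden unit
y = h @ w2.T
print(h)                             # tensor([[2., 0.], [0., 3.]])
print(y)                             # tensor([[ 20.], [-21.]]) -> c*x_1 and d*x_2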

Exercise

Ok, now do it yourself.

Assume:

$$ x = \begin{bmatrix} 1 & 1 \\ 2 & 2 \\ -1 & -1 \\ -2 & -2 \end{bmatrix} $$

Assume your first weight matrix is:

$$ w_1 = \begin{bmatrix} 1 & 1 \\ -1 & -1 \end{bmatrix} $$

First, calculate $z_1 = x w_1^T$.

Answer here

$$ z_1 = \begin{bmatrix} 2 & -2 \\ 4 & -4 \\ -2 & 2 \\ -4 & 4 \end{bmatrix} $$

Then, calculate $y_1 = \text{ReLU}(z_1)$

Answer here

$$ y_1 = \begin{bmatrix} 2 & 0 \\ 4 & 0 \\ 0 & 2 \\ 0 & 4 \end{bmatrix} $$

Now, assume that:

$$ w_2 = \begin{bmatrix} 1 & 1 \end{bmatrix} $$

Calculate $z_2 = y_1 w_2^T$

Answer here

$$ z_2 = \begin{bmatrix} 2 \\ 4 \\ 2 \\ 4 \end{bmatrix} $$

What happened to the first and second data point? What happened to the third and fourth data points? Can we do that using a simple linear transformation?

Answer here Data points 1 and 2 were mapped to the sum of their coordinates ($x_1+x_2$); data points 3 and 4 were mapped to minus the sum of their coordinates ($-(x_1+x_2)$), that is, the same map multiplied by $-1$. As a consequence, opposite inputs produced identical outputs, which is impossible with a single linear transformation.

Hence, what is the role of the non-linearity that separates the layers in our network?

Answer here Non-linearities allow dividing our vector space into regions, and then applying a different linear transformation in each region!
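We can verify the whole exercise in PyTorch (a short sketch using exactly the matrices given above):

In [ ]:
import torch

x = torch.tensor([[1., 1.], [2., 2.], [-1., -1.], [-2., -2.]])
w1 = torch.tensor([[1., 1.], [-1., -1.]])
w2 = torch.tensor([[1., 1.]])

z1 = x @ w1.T              # first linear projection
y1 = torch.relu(z1)        # the ReLU separates the points into two groups
z2 = y1 @ w2.T             # second linear projection
print(z1)
print(y1)
print(z2)                  # tensor([[2.], [4.], [2.], [4.]])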

The Vanishing Gradient problem

ReLU is a powerful non-linearity, but it has an inherent problem.

Exercise

Assume:

$$ x = \begin{bmatrix}1 & 1\end{bmatrix}, $$

$b=-5$

and

$$ w = \begin{bmatrix}1 & 1\end{bmatrix} $$

In this case, what is the value of $z = f(xw^T+b)$, where $f$ is the ReLU function?

Answer here $xw^T+b = 1 + 1 - 5 = -3 < 0$, so $z = f(-3) = 0$.

What is the value of $\frac{dz}{dx_1}$ and $\frac{dz}{dx_2}$?

Answer here In these conditions, $\frac{dz}{dx_1}=\frac{dz}{dx_2}=0$: the ReLU input is negative, and the ReLU has zero derivative for all negative inputs.

Why is this harmful to gradient descent?

Answer here When we apply the chain rule, the derivative of the error with respect to $z$ is multiplied by the derivative of $z$ with respect to $x$. However, this second factor is zero, hence the gradient of the error with respect to $x$ (and with respect to the parameters $w$ and $b$) is also zero.
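We can observe this directly with autograd (a minimal sketch using the numbers from the exercise):

In [ ]:
import torch

x = torch.tensor([[1., 1.]], requires_grad=True)
w = torch.tensor([[1., 1.]])
b = torch.tensor([-5.])

z = torch.relu(x @ w.T + b)   # pre-activation is -3, so z = 0
z.sum().backward()
print(x.grad)                 # tensor([[0., 0.]]): no gradient flows through the dead ReLU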

Visualizing the vanishing gradient

The values of the weight and bias matrices must be trained using gradient descent. However, $f(.)$ has zero gradient for all negative inputs. For this reason, it is common to see output values collapsing onto straight lines - located exactly at the ReLU's kink, where the pre-activation is zero.

In [7]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.nonlinearity = nn.ReLU()

        # Initialize weights to identity and biases to zero
        nn.init.eye_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.eye_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        x = self.nonlinearity(self.fc1(x))
        x = self.fc2(x)
        return x

# Create an instance of the MLP
hidden_size = 2  # Example hidden layer size
mlp_model = MLP(input_size, hidden_size, output_size)
model, outputs = train_model(
    model=mlp_model,
    X=X,
    y=y,
    lr=0.01,
    epochs=5000
)

fig = animate_training(outputs, y, frame_duration=1, title='Linear by parts data, MLP model')
fig.show()
100%|██████████| 5000/5000 [00:02<00:00, 1798.35it/s]

Residual blocks

The zero-gradient problem has been tackled by many approaches. One of the most successful was to create an alternative route for gradients to propagate. This route is called a "residual" connection, and involves adding the block's input to its output, that is:

$$ \hat{y} = x + (f(xw_1^T+b_1)w_2^T+b_2), $$

or, using function aggregation:

$$ \hat{y} = x + \left((xw_1^T+b_1) \circ f \circ (xw_2^T+b_2)\right), $$ where, again, $\circ$ denotes function composition, applied left to right.

flowchart LR;
    Input
    subgraph MLP;
        direction LR;
        L1[Linear] --> ReLU --> L2[Linear]
    end;
    Input --> MLP
    MLP --> Add --> Output
    Input -.->|"Residual connection"| Add
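Before looking at the full model, here is a minimal sketch (with toy weights chosen so that the ReLU is dead) showing that the residual path lets gradients flow even when the ReLU branch has zero gradient:

In [ ]:
import torch

x = torch.tensor([[1., 1.]], requires_grad=True)
w1 = torch.tensor([[1., 1.], [1., 1.]])
b1 = torch.tensor([-5., -5.])   # large negative bias -> dead ReLU
w2 = torch.eye(2)
b2 = torch.zeros(2)

# Without the residual connection: the gradient w.r.t. x is zero
y_plain = torch.relu(x @ w1.T + b1) @ w2.T + b2
y_plain.sum().backward()
print(x.grad)                   # tensor([[0., 0.]])

# With the residual connection: the identity path still carries gradient
x.grad = None
y_res = x + (torch.relu(x @ w1.T + b1) @ w2.T + b2)
y_res.sum().backward()
print(x.grad)                   # tensor([[1., 1.]])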
In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.nonlinearity = nn.ReLU()

        # Initialize weights to identity and biases to zero
        nn.init.eye_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.eye_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        residual = x
        x = self.nonlinearity(self.fc1(x))
        x = self.fc2(x)
        x += residual  # Add the residual connection
        return x
    
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.r1 = ResidualBlock(hidden_size)
        self.r2 = ResidualBlock(hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

        # Initialize weights to identity and biases to zero
        nn.init.eye_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.eye_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        x = self.fc1(x)
        x = self.r1(x)
        x = self.r2(x)
        y = self.fc2(x)
        return y

# Create an instance of the MLP
hidden_size = 2  # Example hidden layer size
mlp_model = MLP(input_size, hidden_size, output_size)
model, outputs = train_model(
    model=mlp_model,
    X=X,
    y=y,
    lr=0.01,
    epochs=5000
)

fig = animate_training(outputs, y, frame_duration=1, title='Linear by parts data, MLP model with residual propagation')
fig.show()
100%|██████████| 5000/5000 [00:05<00:00, 931.68it/s] 

Normalization

The non-linearities allow applying different transforms to each region of the input space. The residual connections avoid the vanishing gradient problem. Now, we add an extra layer of stability by normalizing the data in each layer. Normalization keeps all representations within reasonable values, which helps numerical stability and has been linked to faster convergence in neural networks.

Normalization works as an extra block, which is usually inserted after adding the residual connection:

flowchart LR;
    Input
    subgraph MLP;
        direction LR;
        L1[Linear] --> ReLU --> L2[Linear]
    end;
    Input --> MLP
    MLP --> Add --> Normalize --> Output
    Input -.->|"Residual connection"| Add

Using normalization after each layer, we observe a faster convergence.
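As a small illustration of what a normalization layer does (a sketch with random data; BatchNorm1d is the normalization used in the model below):

In [ ]:
import torch
import torch.nn as nn

h = 10 * torch.randn(100, 2) + 50        # a batch with large mean and spread
bn = nn.BatchNorm1d(2)
h_norm = bn(h)

print(h.mean(dim=0), h.std(dim=0))            # roughly 50 and 10 per feature
print(h_norm.mean(dim=0), h_norm.std(dim=0))  # roughly 0 and 1 per feature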

In [9]:
class ResidualBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.nonlinearity = nn.ReLU()

        # Initialize weights to identity and biases to zero
        nn.init.eye_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.eye_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        residual = x
        x = self.nonlinearity(self.fc1(x))
        x = self.fc2(x)
        x += residual  # Add the residual connection
        return x
    
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.bn3 = nn.BatchNorm1d(hidden_size)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.r1 = ResidualBlock(hidden_size)
        self.r2 = ResidualBlock(hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

        # Initialize weights to identity and biases to zero
        nn.init.eye_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.eye_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        x = self.fc1(x)
        #x = self.bn1(x)
        x = self.r1(x)
        x = self.bn2(x)
        x = self.r2(x)
        x = self.bn3(x)
        y = self.fc2(x)
        return y

# Create an instance of the MLP
hidden_size = 2  # Example hidden layer size
mlp_model = MLP(input_size, hidden_size, output_size)
model, outputs = train_model(
    model=mlp_model,
    X=X,
    y=y,
    lr=0.01,
    epochs=5000
)

fig = animate_training(outputs, y, frame_duration=1, title='Linear by parts data, MLP model with residual propagation<br>and batch normalization')
fig.show()
100%|██████████| 5000/5000 [00:09<00:00, 529.21it/s]

Conclusion

Although our reference states that there is a theoretical upper bound on the number of linear regions created by subsequent ReLU-separated projections, it is still challenging to find the optimal regions and corresponding projections for a particular dataset.

Our toolset for this is listed below (a compact sketch combining all the pieces follows the list):

  1. We can use simple linear regressions or logistic regressions to establish a baseline for our system.
  2. We can use the MLP topology to create potential regions in our dataset. More layers, and more neurons per layer, increase the expressivity of the network, that is, the number of linear regions it can model.
  3. Adding residual connections helps propagate gradients to the earlier layers of the MLP, which favors using the full potential of the network.
  4. Normalization layers lead to a more numerically stable fit.
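A minimal sketch putting these pieces together (the layer sizes and number of blocks are placeholders, not tuned for any particular dataset):

In [ ]:
import torch
import torch.nn as nn

class Block(nn.Module):
    """Linear -> ReLU -> Linear, with a residual connection and normalization."""
    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.BatchNorm1d(hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.norm(x + self.fc2(self.relu(self.fc1(x))))

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_blocks=2):
        super().__init__()
        self.inp = nn.Linear(input_size, hidden_size)
        self.blocks = nn.Sequential(*[Block(hidden_size) for _ in range(n_blocks)])
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.out(self.blocks(self.inp(x)))

net = Net(input_size=2, hidden_size=8, output_size=2)
print(net(torch.randn(5, 2)).shape)   # torch.Size([5, 2])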

Practice

Make a neural network that maps $X$ to $y$ using the data below. Plot an animation of the convergence process, like the ones shown above. Try the different model variations - what happens in each case?

In [10]:
x1 = torch.linspace(-1, 1, 500)
x2 = torch.linspace(1, -1, 500)
X = torch.stack([x1, x2], dim=1)
X += 0.1 * torch.randn(500, 2)  # Adding noise to X

y = 2*X + X**2 +0.5

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0].numpy(), X[:, 1].numpy(), label='X', alpha=0.5)
plt.scatter(y[:, 0].numpy(), y[:, 1].numpy(), label='y', alpha=0.5)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Scatterplot of X and y')
plt.legend()
plt.show()
[Figure: scatterplot of X (input) and y (target) for the practice dataset]
