
Linear Regression - PyTorch Beginner 07

Learn all the basics you need to get started with this deep learning framework! In this part we implement a linear regression algorithm and apply all the concepts that we have learned so far:

  • Training Pipeline in PyTorch
  • Model Design
  • Loss and Optimizer
  • Automatic Training steps with forward pass, backward pass, and weight updates
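To make the last bullet concrete, here is a minimal sketch of a single training step on hypothetical toy data (the values of `X`, `y`, and `lr` are assumptions chosen for illustration, not part of the course code). It checks that, for plain SGD without momentum or weight decay, `optimizer.step()` performs the update `w <- w - lr * grad`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: 4 samples, 1 feature, target y = 2x
X = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = 2.0 * X

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Keep a copy of the initial weight so the update can be verified by hand
w_before = model.weight.detach().clone()

# Forward pass: compute predictions and the loss
loss = criterion(model(X), y)

# Backward pass: autograd fills .grad for every parameter
loss.backward()
grad_w = model.weight.grad.detach().clone()

# Weight update: plain SGD applies w <- w - lr * grad
optimizer.step()

# Reset gradients so they do not accumulate into the next step
optimizer.zero_grad()

# Compare the optimizer's result with the hand-computed update
w_expected = w_before - lr * grad_w
print(torch.allclose(model.weight.detach(), w_expected))
```

The same three calls (`backward()`, `step()`, `zero_grad()`) repeat once per epoch in a full training loop.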

All code from this course can be found on GitHub.

Linear Regression in PyTorch

import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

# 0) Prepare data
X_numpy, y_numpy = datasets.make_regression(n_samples=100, n_features=1, noise=20, random_state=4)

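# Convert to float32 tensors (make_regression returns float64 arrays,
# while nn.Linear uses float32 parameters by default)
X = torch.from_numpy(X_numpy.astype(np.float32))
y = torch.from_numpy(y_numpy.astype(np.float32))
# Reshape y from (n_samples,) to (n_samples, 1) to match the model output
y = y.view(y.shape[0], 1)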

n_samples, n_features = X.shape

# 1) Model
# Linear model f = wx + b
input_size = n_features
output_size = 1
model = nn.Linear(input_size, output_size)

# 2) Loss and optimizer
learning_rate = 0.01

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  

# 3) Training loop
num_epochs = 100
for epoch in range(num_epochs):
    # Forward pass and loss
    y_predicted = model(X)
    loss = criterion(y_predicted, y)

    # Backward pass and update
    loss.backward()
    optimizer.step()

    # Zero the gradients before the next step;
    # backward() would otherwise accumulate into .grad
    optimizer.zero_grad()

    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')

# Plot the data (red dots) and the fitted line (blue)
predicted = model(X).detach().numpy()

plt.plot(X_numpy, y_numpy, 'ro')
plt.plot(X_numpy, predicted, 'b')
plt.show()
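One way to sanity-check a fit like this is to compare the learned weight and bias with the closed-form least-squares solution. The sketch below is self-contained and uses hypothetical synthetic data (`y = 3x + 2` plus noise) rather than the `make_regression` data above; the learning rate and epoch count are also assumptions, picked so that gradient descent converges on this toy problem.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical synthetic data: y = 3x + 2 plus Gaussian noise
rng = np.random.default_rng(0)
X_np = rng.normal(size=(100, 1)).astype(np.float32)
noise = rng.normal(scale=0.5, size=100)
y_np = (3.0 * X_np[:, 0] + 2.0 + noise).astype(np.float32)

X = torch.from_numpy(X_np)
y = torch.from_numpy(y_np).view(-1, 1)

# Same model, loss, and optimizer types as in the tutorial code
torch.manual_seed(0)
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(200):
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Parameters learned by gradient descent
w_sgd = model.weight.item()
b_sgd = model.bias.item()

# Closed-form least squares: append a column of ones for the bias term
A = np.hstack([X_np, np.ones((X_np.shape[0], 1), dtype=np.float32)])
(w_ls, b_ls), *_ = np.linalg.lstsq(A, y_np, rcond=None)

print(f'SGD:           w = {w_sgd:.3f}, b = {b_sgd:.3f}')
print(f'least squares: w = {w_ls:.3f}, b = {b_ls:.3f}')
```

If the two lines of output agree closely, the training loop has converged to the minimum of the MSE loss; a large gap usually points to too few epochs or an unsuitable learning rate.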
