
Logistic Regression - PyTorch Beginner 08

Learn all the basics you need to get started with this deep learning framework! In this part we implement logistic regression and apply the concepts we have learned so far: data preparation, model design, loss and optimizer, and the training loop.

All code from this course can be found on GitHub.

Logistic Regression in PyTorch

import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# 0) Prepare data
bc = datasets.load_breast_cancer()
X, y = bc.data, bc.target

n_samples, n_features = X.shape

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)

# scale: fit on the training set only, then apply the same transform to the test set
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))

# reshape targets to column vectors so they match the model output shape
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)

# 1) Model
# Linear model f = wx + b, sigmoid at the end
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_features)

# 2) Loss and optimizer
num_epochs = 100
learning_rate = 0.01
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 3) Training loop
for epoch in range(num_epochs):
    # Forward pass and loss
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)

    # Backward pass and update
    loss.backward()
    optimizer.step()

    # zero grad before new step
    optimizer.zero_grad()

    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')

# Evaluation: no gradient tracking needed
with torch.no_grad():
    y_predicted = model(X_test)
    y_predicted_cls = y_predicted.round()
    acc = y_predicted_cls.eq(y_test).sum() / float(y_test.shape[0])
    print(f'accuracy: {acc.item():.4f}')
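To make the three core computations in the script above more concrete, here is a minimal sketch in plain Python (no PyTorch) of what the model, the loss, and the accuracy check compute for a handful of samples. The function names `sigmoid`, `predict_proba`, `bce_loss`, and `accuracy`, as well as the toy weights and data, are made up for illustration and are not part of the course code. The sketch also leaves out the numerical safeguards a real library applies, so treat it as an approximation of the idea rather than of PyTorch's implementation.

```python
import math

def sigmoid(z):
    """Map a real-valued score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def predict_proba(x, w, b):
    """Linear model f = w.x + b followed by a sigmoid (hypothetical helper)."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return sigmoid(z)

def bce_loss(probs, targets):
    """Mean binary cross-entropy over a batch.

    Assumes every probability is strictly between 0 and 1;
    no clamping is done in this sketch.
    """
    total = 0.0
    for p, y in zip(probs, targets):
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(probs)

def accuracy(probs, targets):
    """Fraction of samples whose thresholded probability matches the label.

    A probability above 0.5 is treated as class 1, otherwise class 0.
    """
    predicted = [1.0 if p > 0.5 else 0.0 for p in probs]
    correct = sum(1 for c, y in zip(predicted, targets) if c == y)
    return correct / len(targets)

# Toy example with made-up weights and two samples of two features each
w = [0.5, -0.25]
b = 0.1
X = [[2.0, 1.0], [-1.0, 3.0]]
y = [1.0, 0.0]

probs = [predict_proba(x, w, b) for x in X]
print(f"probabilities: {[round(p, 4) for p in probs]}")
print(f"loss = {bce_loss(probs, y):.4f}")
print(f"accuracy: {accuracy(probs, y):.4f}")
```

In the PyTorch version, `nn.Linear` plus `torch.sigmoid` plays the role of `predict_proba`, `nn.BCELoss` plays the role of `bce_loss`, and the `round()` / `eq()` lines play the role of `accuracy`. Gradients and the weight update are handled by `loss.backward()` and `optimizer.step()`, which this sketch does not attempt to reproduce.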
