Python Engineer

Free Python and Machine Learning Tutorials


Build A PyTorch Style Transfer Web App With Streamlit

22 Sep 2020

In this tutorial we build an interactive deep learning app with Streamlit and PyTorch that applies style transfer. It demonstrates how easily interactive web applications can be built with Streamlit, which lets you create beautiful apps for your machine learning or deep learning projects with simple Python scripts. See the official Streamlit website for more info.

You can find the code on GitHub:

The style transfer code is based on this fast neural style code from the official PyTorch examples repo: Fast Neural Style.


It is recommended to create and activate a virtual environment before installing the dependencies:

pip install streamlit
pip install torch torchvision
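To double-check that these packages are importable from the active environment, a small check like the one below can help. This is a minimal sketch that relies only on the standard library's `importlib.util.find_spec`; the helper name `missing_packages` is hypothetical and not part of the tutorial's code.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of top-level module names that cannot be found.

    importlib.util.find_spec returns None when no module with that name
    exists on the current interpreter's search path.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Module names provided by the pip packages installed above
    missing = missing_packages(["streamlit", "torch", "torchvision"])
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found")
```

Running it with the same interpreter that will later run `streamlit` confirms the virtual environment is the one actually in use.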


Download the pretrained models


After downloading, move the saved_models folder into the neural_style folder. Then run:

streamlit run
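If the app cannot find a model, it is worth checking that the pretrained weights ended up where it expects them. The app builds paths of the form `saved_models/<style>.pth`, so the sketch below lists any of the four styles whose `.pth` file is missing. The helper name `missing_models` is hypothetical, and it assumes it is run from inside the neural_style folder.

```python
from pathlib import Path

# Style names offered in the app's sidebar
STYLES = ("candy", "mosaic", "rain_princess", "udnie")


def missing_models(root=".", styles=STYLES):
    """Return the styles whose saved_models/<style>.pth file is absent under root."""
    model_dir = Path(root) / "saved_models"
    return [s for s in styles if not (model_dir / f"{s}.pth").is_file()]


if __name__ == "__main__":
    # Assumes the current working directory is the neural_style folder
    print(missing_models() or "All models found")
```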

The PyTorch functions

To take advantage of Streamlit's caching, we split the original stylize() function into two functions: one that loads the model and one that applies the style transfer:

import re

import streamlit as st
import torch
from torchvision import transforms

# utils.py and transformer_net.py come from the PyTorch fast neural style example
import utils
from transformer_net import TransformerNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache
def load_model(model_path):
    print('load model')
    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(model_path, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        style_model.eval()
        return style_model


@st.cache
def stylize(style_model, content_image, output_image):
    # load the content image and scale it to a 0-255 float tensor, as the model expects
    content_image = utils.load_image(content_image)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        output = style_model(content_image).cpu()
    utils.save_image(output_image, output[0])
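The key-filtering loop in load_model drops entries such as `in1.running_mean` that older checkpoints stored for the InstanceNorm layers, so that `load_state_dict` does not reject them. The sketch below isolates that step on a plain dictionary so the regular expression can be tried without PyTorch; the helper name `drop_deprecated_keys` and the sample keys are illustrative only.

```python
import re

# Matches InstanceNorm running statistics such as "in1.running_mean" or
# "res1.in2.running_var"; re.search lets the match start anywhere in the key.
DEPRECATED_KEY = re.compile(r'in\d+\.running_(mean|var)$')


def drop_deprecated_keys(state_dict):
    """Delete deprecated InstanceNorm running_* entries in place and return the dict."""
    # Iterate over a copy of the keys because entries are deleted while looping
    for k in list(state_dict.keys()):
        if DEPRECATED_KEY.search(k):
            del state_dict[k]
    return state_dict


if __name__ == "__main__":
    sample = {
        "conv1.conv2d.weight": 0,
        "in1.weight": 0,
        "in1.running_mean": 0,
        "res1.in2.running_var": 0,
    }
    print(sorted(drop_deprecated_keys(sample)))
    # → ['conv1.conv2d.weight', 'in1.weight']
```

Learnable InstanceNorm parameters like `in1.weight` survive because the pattern only matches keys ending in `running_mean` or `running_var`.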

The Streamlit App

Implementing the web app is straightforward and can be achieved in only 30 lines:

import streamlit as st
from PIL import Image

import style

st.title('PyTorch Style Transfer')

img = st.sidebar.selectbox(
    'Select Image',
    ('amber.jpg', 'cat.png')
)

style_name = st.sidebar.selectbox(
    'Select Style',
    ('candy', 'mosaic', 'rain_princess', 'udnie')
)

model = "saved_models/" + style_name + ".pth"
input_image = "images/content-images/" + img
output_image = "images/output-images/" + style_name + "-" + img

st.write('### Source image:')
image = Image.open(input_image)
st.image(image, width=400)  # st.image accepts a PIL image or a numpy array

clicked = st.button('Stylize')

if clicked:
    model = style.load_model(model)
    style.stylize(model, input_image, output_image)

    st.write('### Output image:')
    image = Image.open(output_image)
    st.image(image, width=400)