
Build A PyTorch Style Transfer Web App With Streamlit

In this tutorial we build an interactive deep learning app with Streamlit and PyTorch to apply style transfer. It shows how easy it is to build interactive web applications with Streamlit. Streamlit lets you create beautiful apps for your machine learning or deep learning projects with simple Python scripts. See the official Streamlit website for more info.

You can find the code on GitHub: https://github.com/patrickloeber/pytorch-examples.

The style transfer code is based on the Fast Neural Style example from the official PyTorch examples repo.

Installation

It is recommended to create a virtual environment before installing the dependencies:

pip install streamlit
pip install torch torchvision
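Before moving on, it can help to confirm that all three packages are importable from the active environment. This is a minimal sketch using only the standard library; the helper name `check_dependencies` is hypothetical and not part of the tutorial's repo.

```python
import importlib.util
from importlib import metadata


def check_dependencies(packages=("streamlit", "torch", "torchvision")):
    """Return a mapping of package name -> installed version (None if missing)."""
    report = {}
    for name in packages:
        # find_spec returns None when the package cannot be imported at all
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Importable, but without distribution metadata (e.g. a local checkout)
            report[name] = "unknown"
    return report


if __name__ == "__main__":
    for pkg, version in check_dependencies().items():
        print(f"{pkg}: {version or 'NOT INSTALLED'}")
```

Running it inside the virtual environment should list a version for each package; a `NOT INSTALLED` line means the corresponding `pip install` step did not go into that environment.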

Usage

Download the pretrained models

python download_saved_models.py

After downloading, move the saved_models folder into the neural_style folder. Then run:

streamlit run main.py
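The move step can also be scripted. This sketch assumes the download script leaves `saved_models` in the current directory and that `neural_style` is the app folder; the four `.pth` names are taken from the style list used in `main.py`. The function name `install_models` is hypothetical.

```python
import shutil
from pathlib import Path

# Style names offered in main.py (assumed to match the downloaded checkpoints)
STYLES = ("candy", "mosaic", "rain_princess", "udnie")


def install_models(download_dir=".", app_dir="neural_style"):
    """Move saved_models/ into the app folder and return any missing checkpoints."""
    src = Path(download_dir) / "saved_models"
    dst = Path(app_dir) / "saved_models"
    # Only move if the download exists and has not been moved already
    if src.is_dir() and not dst.exists():
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(src), str(dst))
    # Report every expected checkpoint that is not in place
    return [name for name in STYLES if not (dst / f"{name}.pth").is_file()]
```

An empty return value means every style offered by the app has a checkpoint in `neural_style/saved_models`.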

The PyTorch functions

To take advantage of Streamlit's caching (https://docs.streamlit.io/en/latest/caching.html), we split the original stylize() function into two functions: one for loading the model and one for applying the style transfer:
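Why the split matters: Streamlit reruns the whole script on every widget interaction, so anything that is not cached is recomputed on each rerun. The plain-Python sketch below uses `functools.lru_cache` as a stand-in for Streamlit's cache (an analogy, not Streamlit's actual mechanism) and caches only the loader, to show that it runs once per distinct model path while the uncached step runs every time. `fake_load_model`, `fake_stylize` and `rerun` are hypothetical names.

```python
from functools import lru_cache

calls = {"load": 0, "stylize": 0}


@lru_cache(maxsize=None)
def fake_load_model(model_path):
    # Expensive step: executes only once per distinct model_path
    calls["load"] += 1
    return f"model<{model_path}>"


def fake_stylize(model, image):
    # Uncached in this sketch: executes on every rerun
    calls["stylize"] += 1
    return f"{model} applied to {image}"


def rerun(style_name, image):
    """Simulate one top-to-bottom script run triggered by a widget change."""
    model = fake_load_model(f"saved_models/{style_name}.pth")
    return fake_stylize(model, image)


for style_name in ("candy", "candy", "mosaic", "candy"):
    rerun(style_name, "amber.jpg")

print(calls)  # {'load': 2, 'stylize': 4}
```

Four reruns load only two models, because the second and fourth runs reuse the cached `candy` model.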

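# style.py
import re

import streamlit as st
import torch
from torchvision import transforms

# utils.py and transformer_net.py come from the PyTorch fast_neural_style example
import utils
from transformer_net import TransformerNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# cache_resource keeps one shared model instance across reruns instead of reloading it
@st.cache_resource
def load_model(model_path):
    print('load model')
    style_model = TransformerNet()
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine
    state_dict = torch.load(model_path, map_location=device)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
    style_model.to(device)
    style_model.eval()
    return style_model

# The leading underscore tells Streamlit not to hash the model argument;
# the output path already encodes the chosen style, so it still keys the cache
@st.cache_data
def stylize(_style_model, content_image, output_image):
    content_image = utils.load_image(content_image)
    content_transform = transforms.Compose([
        transforms.ToTensor(),                   # PIL image -> float tensor in [0, 1]
        transforms.Lambda(lambda x: x.mul(255))  # rescale to [0, 255] for the network
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        output = _style_model(content_image).cpu()

    utils.save_image(output_image, output[0])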
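The checkpoint-cleaning loop in `load_model` drops InstanceNorm `running_mean`/`running_var` entries, which its comment describes as deprecated. The sketch below applies the same regular expression to a plain dict with made-up key names (a real checkpoint's keys may differ), so you can see which entries are removed without loading any weights. `strip_running_stats` is a hypothetical helper.

```python
import re

# Same pattern as in load_model: InstanceNorm layers named in1, in2, ... with running stats
DEPRECATED_KEY = re.compile(r'in\d+\.running_(mean|var)$')


def strip_running_stats(state_dict):
    """Return a copy of state_dict without the deprecated running_* entries."""
    return {k: v for k, v in state_dict.items() if not DEPRECATED_KEY.search(k)}


# Hypothetical checkpoint keys; the integer values stand in for tensors
checkpoint = {
    "conv1.conv2d.weight": 0,
    "in1.weight": 1,
    "in1.running_mean": 2,
    "in1.running_var": 3,
    "res1.in1.running_mean": 4,
    "deconv1.conv2d.bias": 5,
}

print(sorted(strip_running_stats(checkpoint)))
# ['conv1.conv2d.weight', 'deconv1.conv2d.bias', 'in1.weight']
```

Because `re.search` matches anywhere in the key, nested names such as `res1.in1.running_mean` are removed too. Building a new dict is equivalent to the in-place deletion in `load_model`, where iterating over `list(state_dict.keys())` is what makes deleting during the loop safe.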

The Streamlit App

Implementing the web app is straightforward and can be achieved in only 30 lines:

# main.py
import streamlit as st
from PIL import Image

import style

st.title('PyTorch Style Transfer')

img = st.sidebar.selectbox(
    'Select Image',
    ('amber.jpg', 'cat.png')
)

style_name = st.sidebar.selectbox(
    'Select Style',
    ('candy', 'mosaic', 'rain_princess', 'udnie')
)

model_path = "saved_models/" + style_name + ".pth"
input_image = "images/content-images/" + img
output_image = "images/output-images/" + style_name + "-" + img

st.write('### Source image:')
image = Image.open(input_image)
st.image(image, width=400)  # st.image accepts a PIL image directly

clicked = st.button('Stylize')

if clicked:
    model = style.load_model(model_path)
    style.stylize(model, input_image, output_image)

    st.write('### Output image:')
    image = Image.open(output_image)
    st.image(image, width=400)
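`main.py` builds its three file paths by string concatenation, and saving the result will fail if `images/output-images` is missing, since PIL cannot write into a folder that does not exist. A small `pathlib` helper keeps that logic in one place. This is a sketch that assumes the app runs from the folder containing `saved_models` and `images`; `build_paths` is a hypothetical name, not part of the repo.

```python
from pathlib import Path


def build_paths(img, style_name, root="."):
    """Return (model_path, input_image, output_image) for the chosen image and style."""
    root = Path(root)
    model_path = root / "saved_models" / f"{style_name}.pth"
    input_image = root / "images" / "content-images" / img
    output_dir = root / "images" / "output-images"
    # Create the output folder up front so saving the stylized image cannot fail on it
    output_dir.mkdir(parents=True, exist_ok=True)
    output_image = output_dir / f"{style_name}-{img}"
    # Return strings so the values drop into the existing calls unchanged
    return str(model_path), str(input_image), str(output_image)
```

In `main.py`, the three assignments could then become `model_path, input_image, output_image = build_paths(img, style_name)`.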
