In this tutorial we build an interactive deep learning app with Streamlit and PyTorch that applies style transfer. It demonstrates how easily interactive web applications can be built with Streamlit, which lets you create apps for your machine learning or deep learning projects from simple Python scripts. See the official Streamlit website for more info.
You can find the code on GitHub: https://github.com/python-engineer/pytorch-examples.
The style transfer code is based on this fast neural style code from the official PyTorch examples repo: Fast Neural Style.
Installation
It is recommended to create and activate a virtual environment before installing the dependencies:
pip install streamlit
pip install torch torchvision
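To confirm that the installation worked in the active environment, a small check like the following can help. It is only a sketch: the helper name missing_packages is hypothetical and not part of the tutorial's repository, and it only tests whether each package can be found, not whether the installed versions are compatible.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    # The three packages installed by the pip commands above.
    missing = missing_packages(["streamlit", "torch", "torchvision"])
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies found.")
```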
Usage
Download the pretrained models:
python download_saved_models.py
After downloading, move the saved_models folder into the neural_style folder. Then start the app from inside that folder:
streamlit run main.py
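If the app cannot find a model, the saved_models folder is probably in the wrong place. The sketch below checks for the four checkpoints the app expects, assuming they are named after the styles offered in the sidebar (candy, mosaic, rain_princess, udnie) with a .pth extension. The helper missing_models is hypothetical and not part of the repository.

```python
from pathlib import Path

# Style names offered by the app; each is assumed to map to "<name>.pth".
STYLES = ("candy", "mosaic", "rain_princess", "udnie")


def missing_models(model_dir, styles=STYLES):
    """Return the style names whose checkpoint file is absent from model_dir."""
    model_dir = Path(model_dir)
    return [s for s in styles if not (model_dir / f"{s}.pth").is_file()]


if __name__ == "__main__":
    # Run from the neural_style folder, next to main.py.
    missing = missing_models("saved_models")
    if missing:
        print("Missing checkpoints for:", ", ".join(missing))
    else:
        print("All checkpoints found.")
```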
The PyTorch functions
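Streamlit reruns the whole script on every interaction, so loading the model each time would be wasteful. To take advantage of caching, we split the original stylize() function into two functions, one for loading the model and one for applying the style transfer: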
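# style.py
import re

import streamlit as st
import torch
from torchvision import transforms

# local modules from the fast neural style example
import utils
from transformer_net import TransformerNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cache the loaded model so it is not rebuilt on every rerun of the script.
# Newer Streamlit releases deprecate st.cache in favor of st.cache_resource.
@st.cache
def load_model(model_path):
    print('load model')
    with torch.no_grad():
        style_model = TransformerNet()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
        state_dict = torch.load(model_path, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        style_model.eval()
        return style_model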

@st.cache
def stylize(style_model, content_image, output_image):
    # content_image and output_image are file paths
    content_image = utils.load_image(content_image)
    content_transform = transforms.Compose([
        transforms.ToTensor(),                   # PIL image -> float tensor in [0, 1]
        transforms.Lambda(lambda x: x.mul(255))  # rescale to [0, 255]
    ])
    content_image = content_transform(content_image)
    # add a batch dimension and move the tensor to the model's device
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        output = style_model(content_image).cpu()
    utils.save_image(output_image, output[0])
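
The benefit of splitting loading from stylizing can be illustrated without Streamlit or PyTorch. The sketch below uses functools.lru_cache from the standard library as a stand-in for the Streamlit cache decorator. It is not how Streamlit implements caching, and the dictionary returned by load_model is only a placeholder for a real model, but it shows the same effect: the expensive load runs once per model path, while the per-image work runs every time.

```python
from functools import lru_cache

load_calls = []  # records each real (uncached) load, for demonstration only


@lru_cache(maxsize=None)
def load_model(model_path):
    """Stand-in for the expensive model load; runs once per distinct path."""
    load_calls.append(model_path)
    return {"weights_from": model_path}  # placeholder for a real model object


def stylize(model, content_image):
    """Stand-in for the per-image work that reuses an already loaded model."""
    return f"{content_image} styled with {model['weights_from']}"


# Two images, same style: the second load_model call is served from the cache.
first = stylize(load_model("saved_models/candy.pth"), "amber.jpg")
second = stylize(load_model("saved_models/candy.pth"), "cat.png")
print(len(load_calls), "real load(s) for two stylize calls")
```

Had loading stayed inside stylize(), every new image would have triggered a fresh load, because the cache key would include the image as well as the model path.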
The Streamlit App
Implementing the web app is straightforward and takes only about 30 lines of code:
# main.py
import streamlit as st
from PIL import Image

import style

st.title('PyTorch Style Transfer')

# sidebar widgets for choosing the content image and the style
img = st.sidebar.selectbox(
    'Select Image',
    ('amber.jpg', 'cat.png')
)

style_name = st.sidebar.selectbox(
    'Select Style',
    ('candy', 'mosaic', 'rain_princess', 'udnie')
)

model_path = "saved_models/" + style_name + ".pth"
input_image = "images/content-images/" + img
output_image = "images/output-images/" + style_name + "-" + img

st.write('### Source image:')
image = Image.open(input_image)
st.image(image, width=400)  # st.image accepts a PIL image directly

clicked = st.button('Stylize')

if clicked:
    # load_model is cached, so the model is only loaded once per style
    model = style.load_model(model_path)
    style.stylize(model, input_image, output_image)

    st.write('### Output image:')
    image = Image.open(output_image)
    st.image(image, width=400)
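
The app builds its three paths by string concatenation and relies on the images/output-images folder already being present. As a possible refinement, the sketch below builds the same paths with pathlib and creates the output folder if it is missing. The helper build_paths and its root parameter are hypothetical additions, not part of the original app.

```python
from pathlib import Path


def build_paths(style_name, img, root="."):
    """Return (model_path, input_image, output_image) for a style and image name."""
    root = Path(root)
    model_path = root / "saved_models" / f"{style_name}.pth"
    input_image = root / "images" / "content-images" / img
    output_image = root / "images" / "output-images" / f"{style_name}-{img}"
    # Create the output folder up front so saving the result cannot fail on it.
    output_image.parent.mkdir(parents=True, exist_ok=True)
    return model_path, input_image, output_image


if __name__ == "__main__":
    model_path, input_image, output_image = build_paths("candy", "amber.jpg")
    print(model_path, input_image, output_image, sep="\n")
```

The returned values are Path objects. Wrap them in str() before passing them on if a function expects plain strings.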