Python Machine learning
PyTorch
Jan 6, 2019     7-minute read

1. What is PyTorch and why would you use it?
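PyTorch is a Python library for tensor computation with automatic differentiation (autograd). You write a computation with ordinary tensor operations, PyTorch records them, and calling `backward()` applies the chain rule to fill in gradients. This is the feature the “Hello world” example relies on. A tiny illustration, using a made-up function f(t) = t² + 2t chosen only for demonstration:

```python
import torch

# A scalar tensor whose gradient PyTorch should track.
t = torch.tensor(3.0, requires_grad=True)

# Build f(t) = t ** 2 + 2 * t from ordinary tensor operations.
f = t ** 2 + 2 * t

# backward() applies the chain rule through the recorded operations
# and stores df/dt in t.grad.
f.backward()

# df/dt = 2 * t + 2, i.e. 8 at t = 3
print(t.grad)
```

The same mechanism computes the gradient of the loss with respect to `theta_hat` in the example.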

2. “Hello world” example

Inspired by this article.

Let’s define some data:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


theta = 2
x = np.random.rand(10) * 10
y = x ** theta
data = pd.DataFrame(dict(x=x, y=y))
print(data)
##           x          y
## 0  8.520089  72.591910
## 1  8.610977  74.148918
## 2  4.796133  23.002888
## 3  9.971878  99.438360
## 4  7.678634  58.961420
## 5  7.072061  50.014045
## 6  5.995141  35.941718
## 7  4.744001  22.505545
## 8  4.623789  21.379420
## 9  3.899390  15.205244
plt.figure(1)
plt.scatter(x, y)
plt.show()

theta is the parameter we’ll be estimating. The data follow y = x ** theta exactly, with theta = 2 and no noise, so we can check how close to 2 our optimization algorithm gets us.
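Because the data contain no noise, the loss is exactly zero at theta = 2. Before reaching for PyTorch, a quick grid search in NumPy can confirm where the minimum of the loss sits. This is a minimal sketch: the seed (0), the grid from 1 to 3 with spacing 0.01, and the choice of the square root of the sum of squared errors as the loss are assumptions for illustration.

```python
import numpy as np

np.random.seed(0)  # assumption: any fixed seed works; the post itself sets none
x = np.random.rand(10) * 10
y = x ** 2  # noiseless data with the true theta = 2

# Evaluate sqrt(sum of squared errors) for each candidate theta.
thetas = np.linspace(1.0, 3.0, 201)  # spacing 0.01
losses = np.array([np.sqrt(((y - x ** t) ** 2).sum()) for t in thetas])

best_theta = thetas[np.argmin(losses)]
print(best_theta)
```

A grid search only works because there is a single parameter; gradient descent, as used next, scales to many parameters.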

Estimating the value of theta:

import torch


def root_sse(y, y_hat):
    """Square root of the sum of squared errors.

    This equals RMSE * sqrt(n) rather than RMSE itself (taking the mean of
    an already-summed scalar changes nothing); the learning rate below is
    tuned for this scale.
    """
    return torch.sqrt((y - y_hat).pow(2).sum())


def forward(x, e):
    """Forward pass for our function y = x ** e (e broadcasts over x)."""
    return x.pow(e)


# initial settings
learning_rate = 0.00005

# Plain tensors replace the deprecated torch.autograd.Variable wrapper.
x = torch.tensor(data['x'].values, dtype=torch.float32)
y = torch.tensor(data['y'].values, dtype=torch.float32)

# Start from theta = 1. For a random but reproducible start, call
# torch.manual_seed(...) before drawing a value with e.g. torch.rand(1).
theta_hat = torch.tensor([1.0], requires_grad=True)

loss_history = []
theta_history = []

for i in range(600):
    y_hat = forward(x, theta_hat)
    loss = root_sse(y, y_hat)
    loss_history.append(loss.item())
    loss.backward()  # computes d(loss)/d(theta_hat) into theta_hat.grad

    # Gradient-descent step; no_grad() keeps the update out of the graph.
    with torch.no_grad():
        theta_hat -= learning_rate * theta_hat.grad
        theta_hat.grad.zero_()  # gradients accumulate unless reset
    theta_history.append(theta_hat.item())

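The call to `loss.backward()` is where autograd does the work. As a sanity check, a minimal sketch can compare the gradient autograd returns at theta = 1 with the derivative worked out by hand. For L = sqrt(Σ (y_i − x_i^θ)²), the chain rule gives dL/dθ = −Σ (y_i − x_i^θ) · x_i^θ · ln x_i / L. The x values are copied from the printed table (rounded to six decimals), y is recomputed as x ** 2, and float64 is used so the two numbers can be compared tightly; all three are assumptions of this sketch.

```python
import torch

# x copied from the table printed earlier; y follows y = x ** 2 exactly.
x = torch.tensor([8.520089, 8.610977, 4.796133, 9.971878, 7.678634,
                  7.072061, 5.995141, 4.744001, 4.623789, 3.899390],
                 dtype=torch.float64)
y = x ** 2
theta = torch.tensor([1.0], dtype=torch.float64, requires_grad=True)

# Same loss as in the loop: square root of the sum of squared errors.
y_hat = x ** theta
loss = torch.sqrt(((y - y_hat) ** 2).sum())
loss.backward()  # autograd's dL/dtheta lands in theta.grad

# Hand-derived gradient for comparison (no graph needed).
with torch.no_grad():
    residual = y - y_hat
    analytic = -(residual * y_hat * torch.log(x)).sum() / loss

print(theta.grad.item(), analytic.item())
```

The gradient is negative at theta = 1, so the update `theta_hat -= learning_rate * grad` pushes the estimate upward, toward 2.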
And the estimate is…

print(theta_hat.data)
## tensor([2.0092])

Not bad! Let’s look at how the loss and the estimate evolved during training:

plt.figure(2)
plt.scatter(x=range(len(loss_history)), y=loss_history)
plt.title("Loss history")
plt.show()

plt.figure(3)
plt.scatter(x=range(len(theta_history)), y=theta_history)
plt.title("Theta history")
plt.show()

Gradient descent brought the estimate close to the true value of 2 in about 200 iterations.
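The manual update inside the loop can also be handed to one of PyTorch's built-in optimizers. Here is a minimal sketch using `torch.optim.SGD`, which without momentum applies the same rule, theta ← theta − lr · grad. It assumes the same learning rate, starting point, and number of iterations as above, with x copied from the printed table and y recomputed as x ** 2. It also records the first iteration at which the estimate is within 0.05 of 2; that tolerance is an arbitrary choice for illustration.

```python
import torch

# x copied from the table printed earlier; y follows y = x ** 2 exactly.
x = torch.tensor([8.520089, 8.610977, 4.796133, 9.971878, 7.678634,
                  7.072061, 5.995141, 4.744001, 4.623789, 3.899390])
y = x ** 2

theta_hat = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([theta_hat], lr=0.00005)

theta_history = []
for _ in range(600):
    optimizer.zero_grad()  # reset gradients from the previous step
    loss = torch.sqrt(((y - x ** theta_hat) ** 2).sum())  # root of the SSE
    loss.backward()        # fills theta_hat.grad
    optimizer.step()       # theta_hat <- theta_hat - lr * theta_hat.grad
    theta_history.append(theta_hat.item())

# First iteration at which the estimate is within 0.05 of the true value.
first_close = next(i for i, t in enumerate(theta_history) if abs(t - 2) < 0.05)
print(theta_hat.item(), first_close)
```

Swapping `torch.optim.SGD` for another optimizer, such as `torch.optim.Adam`, only changes the line that constructs `optimizer`; the loop stays the same.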