Simple MNIST weights

In [1]:
import copy
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F

This will not be a CNN, we only need some very basic ANN.

In [2]:
def target_reshape(target, n_cls=None):
    if not n_cls:
        n_cls = len(target.long().unique())
    left = torch.LongTensor(range(len(target))).unsqueeze(0)
    right = target.long().unsqueeze(0)
    idx = torch.cat([left, right]).t()
    new_target = torch.zeros(len(target), n_cls)
    new_target[idx[:, 0], idx[:, 1]] = 1.
    return new_target

#target_reshape(torch.LongTensor([0, 1, 2, 1, 0]))#torch.ones(5).long())
#torch.LongTensor([0, 1, 2, 1, 0]).unique()
In [3]:
# not a CNN, no need to normalize
mnist_train = datasets.MNIST('data', train=True, download=True)
mnist_test = datasets.MNIST('data', train=True, download=True)
mnist_train.data.shape, mnist_train.targets.shape, mnist_test.data.shape, mnist_test.targets.shape
Out[3]:
(torch.Size([60000, 28, 28]),
 torch.Size([60000]),
 torch.Size([60000, 28, 28]),
 torch.Size([60000]))
In [4]:
X_train, X_test = torch.flatten(mnist_train.data, 1), torch.flatten(mnist_test.data, 1)
X_train, X_test = X_train.float(), X_test.float()
X_train.dtype, X_train.shape, X_test.dtype, X_test.shape
Out[4]:
(torch.float32,
 torch.Size([60000, 784]),
 torch.float32,
 torch.Size([60000, 784]))
In [5]:
y_train, y_test = mnist_train.targets.float(), mnist_test.targets.float()
y_train, y_test = target_reshape(y_train), target_reshape(y_test)
y_train.dtype, y_train.shape, y_test.dtype, y_test.shape
Out[5]:
(torch.float32,
 torch.Size([60000, 10]),
 torch.float32,
 torch.Size([60000, 10]))
In [6]:
# torch, seriously, why this does not exists?  it is super boilerplate
class DSFlat(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


kwargs = dict(batch_size=1000, shuffle=True)
loader_train = torch.utils.data.DataLoader(DSFlat(X_train, y_train), **kwargs)
loader_test = torch.utils.data.DataLoader(DSFlat(X_test, y_test), ** kwargs)
In [7]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #self.fc1 = nn.Linear(28, 280)
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.tanh(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        x = self.fc3(x)
        x = torch.tanh(x)
        x = self.fc4(x)
        x = torch.tanh(x)
        x = self.fc5(x)
        x = torch.tanh(x)
        return x


main_model = Net()
In [8]:
def weight_init(layer):
    if isinstance(layer, nn.Linear):
        nn.init.normal_(layer.weight, mean=0., std=1./layer.weight.shape[0])
        nn.init.normal_(layer.bias, mean=0., std=1./layer.bias.shape[0])


main_model.apply(weight_init)
with torch.no_grad():
    print(list(main_model.parameters())[0][:, 0])
tensor([ 8.9700e-03, -7.7628e-03,  6.1073e-03,  4.3121e-03,  5.1590e-03,
         4.9937e-03,  1.9669e-06,  5.8203e-04, -3.5670e-03,  1.3484e-02,
        -2.3590e-03, -5.5416e-03, -4.2532e-03,  1.9974e-02, -7.2682e-03,
        -1.2502e-02,  3.8435e-03,  6.0157e-03,  6.2353e-03,  2.3243e-02,
        -8.9876e-03,  1.7035e-02,  5.1363e-03,  2.2865e-03,  1.3453e-02,
        -2.2533e-03, -4.6464e-03,  4.4941e-03, -1.6753e-03,  1.9458e-03,
        -1.1414e-02,  4.8388e-03,  1.5474e-03,  9.1407e-03, -6.8918e-04,
        -9.4471e-03, -8.1504e-03, -2.6788e-04, -7.2986e-03, -5.3564e-03,
         1.3822e-02, -5.1286e-03,  3.7515e-03, -5.0374e-03, -3.3580e-03,
         2.3971e-04,  1.3948e-02, -8.0490e-03,  7.7263e-03,  6.0016e-03,
         1.9093e-03, -1.9067e-03, -2.8667e-03, -7.6309e-04,  1.6329e-03,
        -8.1186e-03,  1.6430e-02, -9.0990e-03, -1.6359e-02,  6.1584e-03,
        -4.7498e-03, -6.6008e-03,  3.9043e-03,  5.5709e-05,  4.8035e-03,
         3.8247e-03, -2.9911e-03, -1.0898e-02, -1.0261e-02,  1.8101e-03,
        -2.1351e-05,  6.3482e-03,  1.4705e-02,  3.5996e-04,  1.1084e-03,
         2.4933e-03, -8.4690e-03,  2.4193e-02, -5.8261e-03, -1.0607e-02,
        -1.6720e-02,  8.9839e-03,  6.9497e-03, -9.7834e-03,  2.4808e-03,
         1.4056e-02, -3.0682e-03, -1.7787e-02,  2.7480e-03, -1.0192e-02,
         4.7187e-03,  5.5004e-03, -7.7283e-03,  6.0476e-03, -1.3350e-02,
         3.0279e-03,  1.2910e-03, -4.2525e-03,  7.2957e-03, -1.3627e-02,
         9.5277e-03,  7.1334e-03,  1.3018e-03, -6.1815e-03,  3.4147e-03,
         6.2413e-03, -4.3534e-03,  1.5975e-03,  5.5802e-03,  7.6046e-03,
        -3.2308e-03, -5.4937e-03, -2.4917e-03, -3.9858e-03, -2.7521e-03,
         9.1837e-03, -1.2787e-02,  2.9422e-03,  4.0538e-03, -5.6661e-03,
         3.7011e-04, -3.1495e-03,  1.0714e-02, -1.3368e-03, -4.7869e-03,
        -2.5558e-03, -4.9807e-03, -9.5094e-03], requires_grad=True)
In [9]:
def flat_grads(model):
    all_grads = []
    for p in model.parameters():
        if p.grad is not None:
            all_grads.append(torch.flatten(p.grad.detach().clone()))
    return all_grads


#torch.cat(flat_grads(model))
In [10]:
#loss_fun = nn.MSELoss()
#grad_acc = []
def train(model, loader_train, optimizer, epoch, grad_acc=None, verbose=False):
    model.train()
    for batch_idx, (data, target) in enumerate(loader_train):
        #data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        #print(output.shape, target.shape)
        loss = F.mse_loss(output, target)
        #loss = loss_fun(output, target)
        loss.backward()
        optimizer.step()
        if grad_acc is not None:
            grad_acc.append(torch.cat(flat_grads(model)))
        if verbose and batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(loader_train.dataset),
                100. * batch_idx / len(loader_train), loss.item()))
In [11]:
def test(model, loader_test):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader_test:
            output = model(data)
            test_loss += F.mse_loss(output, target, reduction='sum').item()
            y_hat = output.argmax(dim=1)
            y = target.argmax(dim=1)
            correct += y_hat.eq(y).sum().item()

    test_loss /= len(loader_test.dataset)
    test_acc = correct / len(loader_test.dataset)
    print(f'Test set average loss: {test_loss}, ACC: {test_acc}')
In [12]:
model = copy.deepcopy(main_model)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
epochs = 10
grad_acc = []
for epoch in range(1, epochs + 1):
    train(model, loader_train, optimizer, epoch, grad_acc=grad_acc)
    n_grads = len(torch.cat(grad_acc))
    print(f'EPOCH {epoch}, GRADS {n_grads}')
    test(model, loader_test)
    scheduler.step(epoch)
EPOCH 1, GRADS 6918360
Test set average loss: 0.9040954406738281, ACC: 0.1421
EPOCH 2, GRADS 13836720
Test set average loss: 0.8889137339274088, ACC: 0.25683333333333336
EPOCH 3, GRADS 20755080
Test set average loss: 0.8793820475260417, ACC: 0.43123333333333336
EPOCH 4, GRADS 27673440
Test set average loss: 0.8655454620361328, ACC: 0.5471166666666667
EPOCH 5, GRADS 34591800
Test set average loss: 0.8429606353759765, ACC: 0.5513833333333333
EPOCH 6, GRADS 41510160
Test set average loss: 0.8072774586995443, ACC: 0.5327
EPOCH 7, GRADS 48428520
Test set average loss: 0.7607387135823568, ACC: 0.5307333333333333
EPOCH 8, GRADS 55346880
Test set average loss: 0.7139944183349609, ACC: 0.5623833333333333
EPOCH 9, GRADS 62265240
Test set average loss: 0.6711846537272136, ACC: 0.61525
EPOCH 10, GRADS 69183600
Test set average loss: 0.633376557413737, ACC: 0.6466833333333334
In [13]:
#all_grads = torch.cat(grad_acc)
fig, ax = plt.subplots(figsize=(16, 9))
#ax.plot(all_grads[:-1], all_grads[1:])
#for grads in all_grads.split(10**3):
for grads in grad_acc:
    ax.plot(grads[:-1], grads[1:], '.')