import copy
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
This will not be a CNN; a very basic fully connected ANN is enough.
# Turn integer class labels into one-hot targets of shape (len(target), n_cls).
def target_reshape(target, n_cls=None):
if not n_cls:
n_cls = len(target.long().unique())
left = torch.LongTensor(range(len(target))).unsqueeze(0)
right = target.long().unsqueeze(0)
idx = torch.cat([left, right]).t()
new_target = torch.zeros(len(target), n_cls)
new_target[idx[:, 0], idx[:, 1]] = 1.
return new_target
#target_reshape(torch.LongTensor([0, 1, 2, 1, 0]))#torch.ones(5).long())
#torch.LongTensor([0, 1, 2, 1, 0]).unique()
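As a sanity check (hedged: for labels that start at 0 and are contiguous, torch.nn.functional.one_hot should reproduce target_reshape), the assert below compares the two on a toy label vector.
assert torch.equal(target_reshape(torch.LongTensor([0, 1, 2, 1, 0])),
                   F.one_hot(torch.LongTensor([0, 1, 2, 1, 0])).float())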
# not a CNN, no need to normalize
mnist_train = datasets.MNIST('data', train=True, download=True)
mnist_test = datasets.MNIST('data', train=False, download=True)
mnist_train.data.shape, mnist_train.targets.shape, mnist_test.data.shape, mnist_test.targets.shape
X_train, X_test = torch.flatten(mnist_train.data, 1), torch.flatten(mnist_test.data, 1)
X_train, X_test = X_train.float(), X_test.float()
X_train.dtype, X_train.shape, X_test.dtype, X_test.shape
y_train, y_test = mnist_train.targets.float(), mnist_test.targets.float()
y_train, y_test = target_reshape(y_train), target_reshape(y_test)
y_train.dtype, y_train.shape, y_test.dtype, y_test.shape
# torch.utils.data.TensorDataset already covers this boilerplate, but the tiny Dataset wrapper below is kept explicit for clarity (a TensorDataset sketch follows the loaders below).
class DSFlat(torch.utils.data.Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.y[idx]
kwargs = dict(batch_size=1000, shuffle=True)
loader_train = torch.utils.data.DataLoader(DSFlat(X_train, y_train), **kwargs)
loader_test = torch.utils.data.DataLoader(DSFlat(X_test, y_test), **kwargs)
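For reference only: torch.utils.data.TensorDataset should behave exactly like DSFlat here; the commented lines below sketch that alternative, while the DSFlat loaders above remain the ones actually used.
#loader_train = torch.utils.data.DataLoader(
#    torch.utils.data.TensorDataset(X_train, y_train), **kwargs)
#loader_test = torch.utils.data.DataLoader(
#    torch.utils.data.TensorDataset(X_test, y_test), **kwargs)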
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
#self.fc1 = nn.Linear(28, 280)
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 64)
self.fc4 = nn.Linear(64, 32)
self.fc5 = nn.Linear(32, 10)
def forward(self, x):
x = self.fc1(x)
x = torch.tanh(x)
x = self.fc2(x)
x = torch.tanh(x)
x = self.fc3(x)
x = torch.tanh(x)
x = self.fc4(x)
x = torch.tanh(x)
x = self.fc5(x)
x = torch.tanh(x)
return x
main_model = Net()
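Quick shape check (not in the original notebook, just a hedged sanity test): a batch of two flattened 28x28 images should come out as two 10-dimensional vectors.
with torch.no_grad():
    print(main_model(torch.zeros(2, 784)).shape)  # expected: torch.Size([2, 10])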
# Re-initialise every linear layer with small Gaussian weights and biases
# (std inversely proportional to the layer's output width).
def weight_init(layer):
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, mean=0., std=1./layer.weight.shape[0])
nn.init.normal_(layer.bias, mean=0., std=1./layer.bias.shape[0])
main_model.apply(weight_init)
with torch.no_grad():
print(list(main_model.parameters())[0][:, 0])
# Collect a flattened, detached copy of each parameter's gradient into a list.
def flat_grads(model):
all_grads = []
for p in model.parameters():
if p.grad is not None:
all_grads.append(torch.flatten(p.grad.detach().clone()))
return all_grads
#torch.cat(flat_grads(model))
#loss_fun = nn.MSELoss()
#grad_acc = []
def train(model, loader_train, optimizer, epoch, grad_acc=None, verbose=False):
model.train()
for batch_idx, (data, target) in enumerate(loader_train):
#data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
#print(output.shape, target.shape)
loss = F.mse_loss(output, target)
#loss = loss_fun(output, target)
loss.backward()
optimizer.step()
if grad_acc is not None:
grad_acc.append(torch.cat(flat_grads(model)))
if verbose and batch_idx % 20 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(loader_train.dataset),
100. * batch_idx / len(loader_train), loss.item()))
def test(model, loader_test):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in loader_test:
output = model(data)
test_loss += F.mse_loss(output, target, reduction='sum').item()
y_hat = output.argmax(dim=1)
y = target.argmax(dim=1)
correct += y_hat.eq(y).sum().item()
test_loss /= len(loader_test.dataset)
test_acc = correct / len(loader_test.dataset)
print(f'Test set average loss: {test_loss}, ACC: {test_acc}')
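Before training, the untrained main_model should sit roughly at chance accuracy (about 0.1 over 10 classes); running test on it once is a cheap, hedged check that the evaluation loop itself works.
test(main_model, loader_test)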
model = copy.deepcopy(main_model)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
epochs = 10
grad_acc = []
for epoch in range(1, epochs + 1):
train(model, loader_train, optimizer, epoch, grad_acc=grad_acc)
    n_grads = sum(len(g) for g in grad_acc)  # total gradient values recorded so far
    print(f'EPOCH {epoch}, GRADS {n_grads}')
test(model, loader_test)
    scheduler.step()
#all_grads = torch.cat(grad_acc)
fig, ax = plt.subplots(figsize=(16, 9))
#ax.plot(all_grads[:-1], all_grads[1:])
#for grads in all_grads.split(10**3):
for grads in grad_acc:
ax.plot(grads[:-1], grads[1:], '.')
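Assuming the intent of the scatter is to compare each gradient component with the next one in the flattened vector (one colour per recorded batch), the labels below just make that explicit.
ax.set_xlabel('gradient component i')
ax.set_ylabel('gradient component i + 1')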