import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, RandomCrop
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
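
# a simple MLP: every Permuted MNIST task shares the same 10 classes,
# so a single fixed-size model can be trained across all of them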
class SimpleMLP(nn.Module):
    def __init__(self, num_classes=10, input_size=28 * 28):
        super(SimpleMLP, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(inplace=True),
            nn.Dropout()
        )
        self.classifier = nn.Linear(512, num_classes)
        self._input_size = input_size

    def forward(self, x):
        # flatten each image into a single 784-dimensional vector
        x = x.contiguous().view(x.size(0), self._input_size)
        x = self.features(x)
        x = self.classifier(x)
        return x
model = SimpleMLP(num_classes=10)
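
# fixed-seed RNG: the sequence of pixel permutations (and hence the
# benchmark itself) is reproducible across runs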
rng_permute = np.random.RandomState(0)
train_transform = transforms.Compose([
    RandomCrop(28, padding=4),
    ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_transform = transforms.Compose([
    ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# permutation transformation: reorders the flattened pixels of a tensor image
class PixelsPermutation(object):
    def __init__(self, index_permutation):
        self.permutation = index_permutation

    def __call__(self, x):
        return x.view(-1)[self.permutation].view(1, 28, 28)


def get_permutation():
    return torch.from_numpy(rng_permute.permutation(784)).type(torch.int64)
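
# Permuted MNIST benchmark: each incremental step applies one fixed random
# pixel permutation to both the train and test images, producing a stream
# of tasks with the same classes but a different input distribution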
list_train_dataset = []
list_test_dataset = []
permutations = []

n_experiences = 5  # number of incremental steps (value assumed; not given in the original)

# for every incremental step
for _ in range(n_experiences):
    # choose a random permutation of the pixels in the image
    idx_permute = get_permutation()
    current_perm = PixelsPermutation(idx_permute)
    permutations.append(idx_permute)
    # add the permutation to the default dataset transformation
    train_transform_list = train_transform.transforms.copy()
    train_transform_list.append(current_perm)
    new_train_transform = transforms.Compose(train_transform_list)

    test_transform_list = test_transform.transforms.copy()
    test_transform_list.append(current_perm)
    new_test_transform = transforms.Compose(test_transform_list)
    # get the datasets with the constructed transformation
    permuted_train = MNIST(root='./data/mnist', train=True,
                           download=True, transform=new_train_transform)
    permuted_test = MNIST(root='./data/mnist', train=False,
                          download=True, transform=new_test_transform)
    list_train_dataset.append(permuted_train)
    list_test_dataset.append(permuted_test)
model.to(device)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = CrossEntropyLoss()
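
# naive fine-tuning baseline: train on each task strictly in sequence,
# with no mechanism to protect what was learned on earlier permutations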
for task_id, train_dataset in enumerate(list_train_dataset):
    train_data_loader = DataLoader(
        train_dataset, num_workers=4, batch_size=32)

    model.train()
    for iteration, (train_mb_x, train_mb_y) in enumerate(train_data_loader):
        optimizer.zero_grad()
        # move mini-batch data to device
        train_mb_x = train_mb_x.to(device)
        train_mb_y = train_mb_y.to(device)

        logits = model(train_mb_x)
        loss = criterion(logits, train_mb_y)
        loss.backward()
        optimizer.step()
# evaluate the final model on the test set of every task
acc_results = []
model.eval()
with torch.no_grad():
    for task_id, test_dataset in enumerate(list_test_dataset):
        test_data_loader = DataLoader(
            test_dataset, num_workers=4, batch_size=32)

        correct = 0
        for iteration, (test_mb_x, test_mb_y) in enumerate(test_data_loader):
            # Move mini-batch data to device
            test_mb_x = test_mb_x.to(device)
            test_mb_y = test_mb_y.to(device)

            test_logits = model(test_mb_x)
            test_loss = criterion(test_logits, test_mb_y)
            # count correct top-1 predictions
            correct += test_mb_y.eq(test_logits.argmax(dim=1)).sum().item()
        acc_results.append(correct / len(test_dataset))
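
# minimal reporting sketch (not part of the original snippet): per-task
# accuracy after sequential training; low accuracy on the earliest
# permutations is the usual signature of catastrophic forgetting
for task_id, acc in enumerate(acc_results):
    print(f"Task {task_id}: accuracy = {acc:.4f}")
print(f"Average accuracy: {sum(acc_results) / len(acc_results):.4f}")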