Training
Baselines and Strategies Code Examples
"LWF" Example
# LwF (Learning without Forgetting) on SplitMNIST.
# Assumes `args` is a parsed argparse.Namespace providing: hs, lr, cuda,
# lwf_alpha (a list), softmax_temperature, epochs and minibatch_size.
# TODO confirm against the script's argument parser (not shown here).
model = SimpleMLP(hidden_size=args.hs)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
criterion = torch.nn.CrossEntropyLoss()

# Check whether the selected GPU is available; otherwise use the CPU.
# Explicit raises are used instead of `assert` because asserts are removed
# under `python -O`, which would silently skip this validation.
if not (args.cuda == -1 or args.cuda >= 0):
    raise ValueError("cuda must be -1 or >= 0.")
if args.cuda >= 0 and torch.cuda.device_count() <= args.cuda:
    raise ValueError(
        f"{args.cuda + 1} GPU needed. Found {torch.cuda.device_count()}.")
device = 'cpu' if args.cuda == -1 else f'cuda:{args.cuda}'
print(f'Using device: {device}')

# Create the split scenario. The experience count is named once because the
# lwf_alpha validation below depends on it.
n_experiences = 5
scenario = SplitMNIST(n_experiences=n_experiences, return_task_id=False)

interactive_logger = InteractiveLogger()
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    ExperienceForgetting(),
    loggers=[interactive_logger])

# Create the strategy. lwf_alpha must hold either a single value, which is
# passed to LwF as a scalar, or n_experiences values, which are passed as a
# list (presumably one alpha per experience).
if len(args.lwf_alpha) not in (1, n_experiences):
    raise ValueError(
        f'Alpha must be a list of length 1 or {n_experiences} '
        '(one value per experience).')
lwf_alpha = args.lwf_alpha[0] if len(args.lwf_alpha) == 1 else args.lwf_alpha

strategy = LwF(model, optimizer, criterion, alpha=lwf_alpha,
               temperature=args.softmax_temperature,
               train_epochs=args.epochs, device=device,
               train_mb_size=args.minibatch_size, evaluator=eval_plugin)

# Train on the selected scenario with the chosen strategy. After each
# experience, evaluate on the full test stream and keep the result.
print('Starting experiment...')
results = []
for train_batch_info in scenario.train_stream:
    print("Start training on experience ", train_batch_info.current_experience)

    strategy.train(train_batch_info, num_workers=4)
    print("End training on experience ", train_batch_info.current_experience)
    print('Computing accuracy on the test set')
    results.append(strategy.eval(scenario.test_stream[:]))
Copied!

๐Ÿค Run it on Google Colab

You can run this chapter and play with it on Google Colaboratory:
Notebook currently unavailable.
Last modified 8mo ago
Export as PDF
Copy link