from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,\
loss_metrics, timing_metrics, cpu_usage_metrics, StreamConfusionMatrix,\
disk_usage_metrics, gpu_usage_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import Naive
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
def main(num_workers=4):
    """Train a Naive strategy on SplitMNIST and evaluate after each experience.

    Args:
        num_workers: number of worker processes forwarded to
            ``cl_strategy.train`` and ``cl_strategy.eval``.

    Returns:
        A list with one element per training experience: the value returned
        by ``cl_strategy.eval`` on the whole test stream after that
        experience was trained.
    """
    benchmark = SplitMNIST(n_experiences=5)
    model = SimpleMLP(num_classes=benchmark.n_classes)

    # DEFINE THE EVALUATION PLUGIN and LOGGERS
    # The evaluation plugin manages the metrics computation.
    # It takes as argument a list of metrics, collects their results and
    # returns them to the strategy it is attached to.
    tb_logger = TensorboardLogger()
    interactive_logger = InteractiveLogger()

    # The text log is opened in a context manager so the handle is closed
    # even if training or evaluation raises. Everything that may log through
    # text_logger therefore runs inside this block.
    with open('log.txt', 'a') as log_file:
        text_logger = TextLogger(log_file)
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(
                minibatch=True, epoch=True, experience=True, stream=True
            ),
            loss_metrics(
                minibatch=True, epoch=True, experience=True, stream=True
            ),
            timing_metrics(epoch=True),
            cpu_usage_metrics(experience=True),
            forgetting_metrics(experience=True, stream=True),
            StreamConfusionMatrix(
                num_classes=benchmark.n_classes, save_image=False
            ),
            disk_usage_metrics(
                minibatch=True, epoch=True, experience=True, stream=True
            ),
            loggers=[interactive_logger, text_logger, tb_logger],
        )

        # CREATE THE STRATEGY INSTANCE (NAIVE)
        cl_strategy = Naive(
            model,
            SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(),
            train_mb_size=500,
            train_epochs=1,
            eval_mb_size=100,
            evaluator=eval_plugin,
        )

        # TRAINING LOOP
        print('Starting experiment...')
        results = []
        for experience in benchmark.train_stream:
            print("Start of experience: ", experience.current_experience)
            print("Current Classes: ", experience.classes_in_this_experience)

            # train also returns the metric values; they are reported by
            # the attached loggers, so the return value is not kept here.
            cl_strategy.train(experience, num_workers=num_workers)
            print('Training completed')

            print('Computing accuracy on the whole test set')
            # Evaluating after every experience lets the stream-level
            # metrics (e.g. forgetting) track performance over time.
            results.append(
                cl_strategy.eval(benchmark.test_stream,
                                 num_workers=num_workers)
            )

    return results


if __name__ == "__main__":
    # Entry guard: with num_workers > 0, platforms that start worker
    # processes with "spawn" (e.g. Windows, macOS) re-import this module in
    # each worker. Without the guard, every worker would re-run the training.
    main()