avalanche

Eval stream in JointTraining strategy

Open · evertonaleixo opened this issue 3 years ago · 1 comment

I am not sure whether this is a bug or a design decision.

The JointTraining strategy merges the data of all training experiences (for example, 3 experiences) into a single dataset. It does this in the 'train_dataset_adaptation' hook. However, in the eval phase at the end of each epoch, it calls the '_periodic_eval' method with the eval data of the 3 experiences kept as separate experiences, without merging them.

One problem with this behavior is that, when we use the EarlyStoppingPlugin, it only considers the metric of the last evaluated experience rather than a metric over all the validation data. This does not make much sense to me.
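To make the concern concrete, here is a minimal, library-independent sketch of patience-based early stopping. The per-epoch accuracies are made-up illustrative numbers, not results from the script below, and `epochs_until_stop` is a hypothetical helper, not Avalanche's EarlyStoppingPlugin. It shows that monitoring only the last experience can stop training at a different epoch than monitoring the average over all experiences.

```python
from statistics import mean


def epochs_until_stop(history, patience):
    """Return the 1-based epoch at which patience-based early stopping triggers.

    `history` holds the monitored metric (higher is better) after each epoch.
    Training stops once the best value has not improved for `patience`
    consecutive epochs; if that never happens, the last epoch is returned.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, value in enumerate(history, start=1):
        if value > best:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(history)


# Illustrative (made-up) validation accuracy per epoch for 3 experiences.
per_exp_acc = [
    [0.60, 0.70, 0.75, 0.78, 0.80, 0.81, 0.82],  # experience 0
    [0.55, 0.65, 0.72, 0.76, 0.79, 0.80, 0.81],  # experience 1
    [0.50, 0.80, 0.79, 0.78, 0.77, 0.76, 0.75],  # experience 2 (evaluated last)
]

last_only = per_exp_acc[-1]
averaged = [mean(epoch_vals) for epoch_vals in zip(*per_exp_acc)]

print(epochs_until_stop(last_only, patience=3))  # 5
print(epochs_until_stop(averaged, patience=3))   # 7
```

With patience 3, following only the last experience stops at epoch 5, while the average keeps improving through epoch 7.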

🐜 To Reproduce

```python
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, loss_metrics
from avalanche.evaluation.metrics import timing_metrics, cpu_usage_metrics
from avalanche.evaluation.metrics import confusion_matrix_metrics, disk_usage_metrics

from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.benchmarks.scenarios import NCScenario
from avalanche.benchmarks.generators import benchmark_with_validation_stream

from avalanche.models import SimpleCNN, SimpleMLP
from avalanche.logging import TextLogger, TensorboardLogger, InteractiveLogger

from avalanche.training.plugins import EvaluationPlugin, EarlyStoppingPlugin
from avalanche.training.strategies import Naive, JointTraining, BaseStrategy

# Use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

modelUpper = SimpleMLP()
optimizerUpper = SGD(modelUpper.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()

# 3 PermutedMNIST experiences; 20% of each training experience is split off
# into a 'valid' stream that is used for early stopping.
permutMnistBenchmark = PermutedMNIST(
    n_experiences=3,
    seed=1234,
)

permutMnistFullBenchmark = benchmark_with_validation_stream(permutMnistBenchmark, validation_size=0.2)

# Stop training if the monitored metric on the 'valid' stream does not improve for 3 epochs.
espUpper = EarlyStoppingPlugin(patience=3, val_stream_name='valid')

evalPluginUpper = EvaluationPlugin(
    accuracy_metrics(epoch=True, experience=True, stream=True),
    loss_metrics(epoch=True, experience=True, stream=True),
    timing_metrics(epoch=True, experience=True, stream=True),
    cpu_usage_metrics(experience=True, stream=True),
    confusion_matrix_metrics(num_classes=permutMnistBenchmark.n_classes_per_exp[0], save_image=True, stream=True),
    disk_usage_metrics(experience=True, stream=True),
    benchmark=permutMnistFullBenchmark,
    strict_checks=False,
    loggers=[TextLogger(open('log_upper.txt', 'a')), TensorboardLogger(), InteractiveLogger()]
)

upperBoundStrategy = JointTraining(
    modelUpper,
    optimizerUpper,
    criterion,
    train_mb_size=64,
    device=device,
    eval_mb_size=64,
    train_epochs=200000,  # effectively unbounded: training should end via early stopping
    eval_every=1,         # evaluate the eval streams after every epoch (via '_periodic_eval')
    plugins=[espUpper],
    evaluator=evalPluginUpper
)

# JointTraining merges all training experiences into a single dataset, but the
# 'valid' stream is still evaluated experience by experience.
upperBoundStrategy.train(permutMnistFullBenchmark.train_stream, [permutMnistFullBenchmark.valid_stream])
```

🐝 Expected behavior

I think the JointTraining strategy could merge the eval data of all experiences in its 'train' method. Alternatively, BaseStrategy could expose a hook before the '_periodic_eval' method, and JointTraining could do the merge there, as it already does in the 'train_dataset_adaptation' method. The first option is simpler and an easy hotfix; the second would give more flexibility for future strategies.
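The second option could look roughly like the sketch below. It is plain Python with hypothetical class and method names (`Trainer`, `JointTrainer`, `eval_streams_adaptation`); it is not Avalanche's BaseStrategy API. The base trainer calls a hook on the eval streams right before periodic evaluation, and the joint trainer overrides that hook to merge all experiences of each stream into one, mirroring what 'train_dataset_adaptation' does for the training data.

```python
from typing import Dict, List, Tuple

# Assumption for the sketch: an experience is a list of (prediction, label)
# pairs and a stream is a list of experiences. Real strategies hold datasets.
Experience = List[Tuple[int, int]]
Stream = List[Experience]


class Trainer:
    """Hypothetical base trainer with a hook that runs before periodic eval."""

    def eval_streams_adaptation(self, streams: Dict[str, Stream]) -> Dict[str, Stream]:
        # Default behaviour: keep every experience separate.
        return streams

    def periodic_eval(self, streams: Dict[str, Stream]) -> Dict[str, List[float]]:
        streams = self.eval_streams_adaptation(streams)
        results = {}
        for name, stream in streams.items():
            # One accuracy value per (possibly merged) experience.
            results[name] = [
                sum(pred == label for pred, label in exp) / len(exp) for exp in stream
            ]
        return results


class JointTrainer(Trainer):
    """Merges all experiences of each eval stream into a single experience."""

    def eval_streams_adaptation(self, streams: Dict[str, Stream]) -> Dict[str, Stream]:
        return {
            name: [[sample for exp in stream for sample in exp]]
            for name, stream in streams.items()
        }


valid = {"valid": [[(0, 0), (1, 0)], [(1, 1), (1, 1), (0, 1), (2, 2)]]}
print(Trainer().periodic_eval(valid))       # {'valid': [0.5, 0.75]}
print(JointTrainer().periodic_eval(valid))  # {'valid': [0.6666666666666666]}
```

Keeping the hook in the base class means other strategies can reuse it without touching '_periodic_eval' itself.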

evertonaleixo · Apr 16 '22 14:04

I believe the main problem is that it is difficult to express exactly which metric early stopping should monitor. That depends on many factors (which stream, which experiences), and it may also change over time (e.g., the current experience), so providing a metric-name string is not sufficient.

I strongly prefer not to modify the eval stream (otherwise metric names become more ambiguous) and instead to make it possible to use metrics computed over the entire stream.

AntonioCarta · Apr 19 '22 12:04

Closing, since we are going in a different direction: stream-level metrics should be computed by aggregating experience-level metrics, rather than from the (possibly partial) stream passed to eval.
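One way to read this direction, as a sketch under stated assumptions: keep evaluating each experience separately, store the latest result per experience, and derive the stream-level value from those stored results, so a later eval call on only part of the stream does not discard the other experiences. Weighting each experience's accuracy by its number of samples is an assumption here (an unweighted mean is another option), and `update_results` and `stream_accuracy` are hypothetical helpers, not Avalanche functions.

```python
from typing import Dict, Tuple

# Latest (accuracy, n_samples) for each experience id. Entries persist across
# eval calls, so evaluating a partial stream only updates the covered ids.
ExpResults = Dict[int, Tuple[float, int]]


def update_results(results: ExpResults, exp_id: int, accuracy: float, n_samples: int) -> None:
    results[exp_id] = (accuracy, n_samples)


def stream_accuracy(results: ExpResults) -> float:
    """Sample-weighted mean of the stored per-experience accuracies."""
    total = sum(n for _, n in results.values())
    return sum(acc * n for acc, n in results.values()) / total


results: ExpResults = {}
update_results(results, 0, 0.90, 100)
update_results(results, 1, 0.80, 100)
update_results(results, 2, 0.60, 200)
print(stream_accuracy(results))  # 0.725

# A later eval call that only covers experience 2 updates that entry alone.
update_results(results, 2, 0.70, 200)
print(stream_accuracy(results))  # 0.775
```

An early-stopping rule could then monitor this single aggregated value instead of whichever experience happened to be evaluated last.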

AntonioCarta · Jul 26 '23 09:07