Spaces:
Runtime error
Runtime error
| import torch | |
| from collections import defaultdict | |
| from .metrics import * | |
class EvaluatorBase:
    ''' To use this class, you should inherit it and implement the `eval` method.

    The evaluator keeps an `accumulator` that maps a metric name to a list of
    tensors. `get_results` concatenates those tensors along dim 0 and reports
    the mean of each metric.
    '''

    def __init__(self):
        # Metric name -> list of tensors to be concatenated along dim 0.
        # NOTE(review): presumably subclasses append per-batch tensors to this
        # inside `eval` -- confirm against the concrete evaluators.
        self.accumulator = defaultdict(list)

    def eval(self, **kwargs):
        ''' Evaluate the metrics on the data. Must be implemented by subclasses. '''
        raise NotImplementedError

    def get_results(self, chosen_metric=None):
        ''' Get the current mean results.

        ### Args
        - chosen_metric: a metric name, a container of metric names, or None.
            Only chosen metrics will be compacted and returned; None selects
            every metric.

        ### Returns
        - A dict mapping each selected metric name to a Python number: the
          mean over dim 0 of its accumulated values. Metrics with nothing
          accumulated yet are omitted.
        '''
        compacted = self._compact_accumulator(chosen_metric)
        # `.item()` requires the dim-0 mean to hold exactly one element.
        return {k: v.mean(dim=0).item() for k, v in compacted.items()}

    def _compact_accumulator(self, chosen_metric=None):
        ''' Compact the accumulator list and return the compacted results.

        For every selected metric, the list of accumulated tensors is
        concatenated along dim 0 and the list is replaced by a single-element
        list holding that tensor, so later calls do not re-concatenate the
        same pieces.

        ### Args
        - chosen_metric: a metric name, a container of metric names, or None
            (which selects every metric).

        ### Returns
        - A dict mapping each selected, non-empty metric name to its
          concatenated tensor.
        '''
        # A bare string would otherwise be matched by substring
        # (`'MPJPE' in 'PA-MPJPE'` is True); treat it as one exact metric name.
        if isinstance(chosen_metric, str):
            chosen_metric = (chosen_metric,)

        ret = {}
        # Only existing keys are reassigned below, so the dict size does not
        # change while it is being iterated.
        for k, v in self.accumulator.items():
            # Only chosen metrics will be compacted.
            if chosen_metric is not None and k not in chosen_metric:
                continue
            # Merely reading a missing key of the defaultdict creates an empty
            # list; `torch.cat` rejects empty lists, so skip such entries.
            if not v:
                continue
            ret[k] = torch.cat(v, dim=0)
            self.accumulator[k] = [ret[k]]
        return ret