
[WIP] Minimal changes necessary for Spock integration

Open • ncilfone opened this issue 4 years ago • 1 comment

This PR encompasses the minimal set of changes needed to integrate Spock.

Unit tests were adjusted for all changes, and 100% of them pass.

The only fundamental change to the API is that the TreeBandit parameters are now specified explicitly (as of sklearn>=1.0.2) instead of being passed as an empty Dict. This is due to the nature of Spock type checking: the underlying type of a dictionary value cannot be typing.Any, because isinstance raises a TypeError for it (typing.Any is not really a type per se). Thus the configuration for TreeBandit is now:

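from enum import Enum
from typing import Optional

from spock import spock
from spock.utils import ge


class _DTCCriterion(Enum):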
    squared_error = "squared_error"
    friedman_mse = "friedman_mse"
    absolute_error = "absolute_error"
    poisson = "poisson"


class _DTCSplitter(Enum):
    best = "best"
    random = "random"


@spock
class TreeBandit:
    """TreeBandit Neighborhood Policy.
    This policy fits a decision tree for each arm using context history.
    It uses the leaves of these trees to partition the context space into regions
    and keeps a list of rewards for each leaf.
    To predict, it receives a context vector and goes to the corresponding
    leaf at each arm's tree and applies the given context-free MAB learning policy
    to predict expectations and choose an arm.
    The TreeBandit neighborhood policy is compatible with the following
    context-free learning policies only: EpsilonGreedy, ThompsonSampling and UCB1.
    The TreeBandit neighborhood policy is a modified version of
    the TreeHeuristic algorithm presented in:
    Adam N. Elmachtoub, Ryan McNellis, Sechan Oh, Marek Petrik
    A Practical Method for Solving Contextual Bandit Problems Using Decision Trees, UAI 2017
    Attributes
    ----------
    criterion, splitter, max_depth, min_samples_split, min_samples_leaf,
    min_weight_fraction_leaf, max_features, random_state, max_leaf_nodes,
    min_impurity_decrease, ccp_alpha
        Parameters of the decision tree.
        The names match the parameters of sklearn.tree.DecisionTreeRegressor.
        When a parameter is not given, the default declared on this class is
        used; these defaults follow sklearn.tree.DecisionTreeRegressor.
    Example
    -------
        >>> from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
        >>> list_of_arms = ['Arm1', 'Arm2']
        >>> decisions = ['Arm1', 'Arm1', 'Arm2', 'Arm1']
        >>> rewards = [20, 17, 25, 9]
        >>> contexts = [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 1, 0], [3, 2, 1, 0]]
        >>> mab = MAB(list_of_arms, LearningPolicy.EpsilonGreedy(epsilon=0), NeighborhoodPolicy.TreeBandit())
        >>> mab.fit(decisions, rewards, contexts)
        >>> mab.predict([[3, 2, 0, 1]])
        'Arm2'
    """

    criterion: Optional[_DTCCriterion] = _DTCCriterion.squared_error
    splitter: Optional[_DTCSplitter] = _DTCSplitter.best
    max_depth: Optional[int] = None
    min_samples_split: int = 2
    min_samples_leaf: int = 1
    min_weight_fraction_leaf: float = 0.0
    max_features: Optional[int] = None
    random_state: Optional[int] = None
    max_leaf_nodes: Optional[int] = None
    min_impurity_decrease: float = 0.0
    ccp_alpha: float = 0.0

    def __post_hook__(self):
        try:
            if self.ccp_alpha is not None:
                ge(self.ccp_alpha, bound=0.0)
        except Exception as e:
            raise ValueError(
                f"`{self.__class__.__name__}` could not be instantiated -- spock message: {e}"
            )
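A minimal sketch of the type-checking limitation described earlier: isinstance refuses typing.Any and subscripted generics such as Dict[str, Any], which is why each tree parameter needs a concrete type instead of a free-form dictionary. The helper is_checkable is hypothetical and is not part of Spock or mabwiser; it only probes whether isinstance accepts a given type.

```python
import typing


def is_checkable(tp: object) -> bool:
    """Return True if isinstance() accepts `tp` as its second argument."""
    try:
        isinstance(0, tp)
        return True
    except TypeError:
        # Raised for typing.Any and for subscripted generics like Dict[str, Any].
        return False


print(is_checkable(int))                           # True
print(is_checkable(typing.Any))                    # False
print(is_checkable(typing.Dict[str, typing.Any]))  # False
```

By contrast, an Enum class such as _DTCCriterion is an ordinary class, so isinstance checks against it succeed.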

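To illustrate how the explicit fields could be handed back to sklearn, here is a sketch that uses a plain dataclass as a stand-in for the @spock class so it runs without Spock installed. The names TreeBanditConfig, Criterion, Splitter and to_tree_kwargs are hypothetical, not mabwiser or Spock API; __post_init__ mirrors the ccp_alpha >= 0 bound enforced by the __post_hook__ above.

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, Optional


class Criterion(Enum):
    squared_error = "squared_error"
    friedman_mse = "friedman_mse"
    absolute_error = "absolute_error"
    poisson = "poisson"


class Splitter(Enum):
    best = "best"
    random = "random"


@dataclass
class TreeBanditConfig:
    # Hypothetical plain-dataclass stand-in for the @spock TreeBandit class.
    criterion: Optional[Criterion] = Criterion.squared_error
    splitter: Optional[Splitter] = Splitter.best
    max_depth: Optional[int] = None
    min_samples_split: int = 2
    min_samples_leaf: int = 1
    min_weight_fraction_leaf: float = 0.0
    max_features: Optional[int] = None
    random_state: Optional[int] = None
    max_leaf_nodes: Optional[int] = None
    min_impurity_decrease: float = 0.0
    ccp_alpha: float = 0.0

    def __post_init__(self) -> None:
        # Same bound as the __post_hook__ check: ccp_alpha must be >= 0.
        if self.ccp_alpha < 0.0:
            raise ValueError(f"ccp_alpha must be >= 0, got {self.ccp_alpha}")

    def to_tree_kwargs(self) -> Dict[str, Any]:
        """Build keyword arguments for sklearn.tree.DecisionTreeRegressor."""
        kwargs: Dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            # sklearn expects plain strings, so unwrap Enum members.
            kwargs[f.name] = value.value if isinstance(value, Enum) else value
        return kwargs
```

With scikit-learn installed, the result could then be unpacked as DecisionTreeRegressor(**TreeBanditConfig(max_depth=3).to_tree_kwargs()), since every field name matches a DecisionTreeRegressor parameter.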
Notable Fixes

  • Fixed failing examples (Random, CustomizedMAB, ParallelMAB)
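  • Fixed incorrect definitions of TreeBandit parameters: the implementation uses DecisionTreeRegressor, while the configuration was using DecisionTreeClassifier, and the two have different parameters
  • Fixed all NoReturn annotations, which were incorrect type hints (NoReturn means a function never returns; functions that return nothing should be annotated with None)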
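As a short illustration of the NoReturn fix listed above: typing.NoReturn marks a function that never returns normally (it always raises or exits), so a function that simply finishes without a return value should be annotated -> None. The function names below are hypothetical and not taken from mabwiser.

```python
from typing import NoReturn


def log_message(msg: str) -> None:
    # Finishes normally and implicitly returns None, so the hint is None.
    print(msg)


def fail(msg: str) -> NoReturn:
    # Never returns normally: it always raises.
    raise ValueError(msg)
```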

ncilfone • Mar 24 '22 16:03

@bkleyn @skadio

This is the minimal set of changes needed for Spock integration without heavily refactoring the backend and/or front-end API.

ncilfone • Mar 24 '22 16:03

Closing this one since it has gone stale over time (many thanks to @ncilfone!)

skadio • Sep 21 '22 03:09