Source code for skmine.callbacks

"""
Callback API for scikit-mine
"""
import re
from inspect import signature, getsource


def has_self_assigment(f):
    """
    Tell whether the source code of ``f`` assigns to ``self`` or to one of
    its direct attributes.

    Parameters
    ----------
    f: callable
        Callable whose source code is inspected.

    Returns
    -------
    bool
        True if a statement such as ``self = ...`` or ``self.attr = ...`` is
        found in the source of ``f``. False if no such statement is found, or
        if the source code cannot be retrieved (builtins, callables created
        with ``exec``, ...).

    Notes
    -----
    The check is purely textual. Comparisons (``self.attr == value``) are not
    reported. Augmented assignments (``self.attr += 1``) and assignments on
    nested attributes (``self.a.b = 1``) are not detected.
    """
    try:
        source = getsource(f)
    except (TypeError, OSError):
        # TypeError: builtins and other objects without Python source.
        # OSError: the source file cannot be located or read.
        return False
    # ``\b`` avoids matching names that merely end with "self" (``myself``);
    # ``(?!=)`` avoids mistaking the ``==`` comparison for an assignment.
    return re.search(r"\bself(\.\w+)?\s*=(?!=)", source) is not None


def _get_params(fn):
    assert callable(fn)
    try:
        sig = signature(fn)
        params = list(sig.parameters)
    except ValueError:
        params = list()
    return params


def post(self, func_name, callback):
    """
    Decorator, call ``callback`` on returned values, after the main function.

    Parameters
    ----------
    self: object
        Object owning the method to wrap.
    func_name: str
        Name of the method of ``self`` to wrap.
    callback: callable
        Called after each execution of the wrapped method. If it declares a
        parameter named ``self``, the owning object is passed first:

        - ``callback(self)`` when ``self`` is its only parameter
        - ``callback(self, result)`` when it declares two parameters
        - ``callback(self, *result)`` when it declares more

        Otherwise ``callback(*result)`` is tried first, then
        ``callback(result)`` if the first form raises ``TypeError``.

    Returns
    -------
    callable
        A wrapper around the original method. It forwards its arguments,
        triggers ``callback`` and returns the original result unchanged.
    """
    # local import, so that this function does not rely on module-level imports
    from functools import wraps

    func = getattr(self, func_name)
    assert callable(func)
    callback_params = _get_params(callback)
    wants_self = "self" in callback_params
    n_params = len(callback_params)

    @wraps(func)  # keep name, docstring and signature of the wrapped method
    def _(*args, **kwargs):
        res = func(*args, **kwargs)
        if not wants_self:
            args_ = res
        elif n_params == 1:
            args_ = (self,)
        elif n_params == 2:
            args_ = (self, res)
        else:
            # tuple() so that lists and other iterables can be unpacked too
            args_ = (self,) + tuple(res)
        try:
            callback(*args_)
        except TypeError:
            # Unpacking did not fit the callback (e.g. ``list.extend``, or a
            # non-iterable result): pass the value as a single argument.
            # NOTE: a TypeError raised from inside the callback also lands here.
            callback(args_)  # list.extend
        return res

    return _


class CallBacks(dict):
    """
    A collection of callbacks

    Works by defining functions to be called after the execution of the function they target

    Parameters
    ----------
    key-value pairs
        Keys must be string, values must be callables

    Raises
    ------
    TypeError
        If one of the values is not callable
    ValueError
        If the source code of a callback assigns to ``self``

    Examples
    --------
    >>> class A():
    ...     def f(self):
    ...         return 10
    ...
    >>> from skmine.callbacks import CallBacks
    >>> stack = list()
    >>> callbacks = CallBacks(f=stack.append)
    >>> a = A()
    >>> callbacks(a)
    >>> a.f()
    10
    >>> stack
    [10]
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._check()

    def _check(self):
        """Validate that every value is a callable which does not assign to ``self``."""
        # TODO : inspect source code from callbacks and check no assigment on inner self
        for v in self.values():
            if not callable(v):
                raise TypeError(f"values must be callables, found {type(v)}")
            if has_self_assigment(v):  # TODO : only allow lambdas or builtins ?
                raise ValueError("callbacks should not modify `self` attributes")

    def _frozen(self, *args, **kwargs):
        """Reject any in-place modification, callbacks are fixed at creation."""
        raise NotImplementedError(f"{type(self)} is immutable")

    # every mutating entry point of ``dict`` is disabled, not only item
    # assignment, so that the validation done in ``_check`` cannot be bypassed
    __setitem__ = _frozen
    __delitem__ = _frozen
    __ior__ = _frozen
    update = _frozen
    pop = _frozen
    popitem = _frozen
    clear = _frozen
    setdefault = _frozen

    def __call__(self, miner):
        """
        Attach the callbacks to ``miner``.

        Every method of ``miner`` named by a key is replaced, on this instance
        only, by a wrapper which triggers the corresponding callback after
        each call.

        Parameters
        ----------
        miner: object
            Object whose methods are targeted by the callbacks

        Raises
        ------
        ValueError
            If a key does not match any callable attribute of ``miner``
        """
        miner_methods = set()
        for f_name in dir(miner):
            # accessing an attribute may execute arbitrary code (properties),
            # a failure on one attribute should not prevent attaching callbacks
            try:
                if callable(getattr(miner, f_name)):
                    miner_methods.add(f_name)
            except Exception:
                print(f"warning : `f_name`='{f_name}' return an error for `callable(getattr(miner, f_name)`")

        for f_name in self:
            if f_name not in miner_methods:
                raise ValueError(
                    f"{f_name} found in callbacks while there is not corresponding function"
                )

        for callback_name, callback in self.items():
            new_meth = post(miner, callback_name, callback)
            # TODO : lock assignement to "once for all"
            # re-executing a cell in a notebook can lead to callbacks being called many times
            setattr(miner, callback_name, new_meth)
def _print_positive_gain(self, data_size, model_size, *_):
    """Print the given sizes when their sum is lower than the sizes stored on ``self``."""
    previous_total = self.model_size_ + self.data_size_
    new_total = data_size + model_size
    gain = previous_total - new_total
    if not gain > 0.01:
        return
    template = "data size : {:.2f} | model size : {:.2f}"
    print(template.format(data_size, model_size))


def _print_candidates_size(self, candidates):
    """Print the number of candidates."""
    n_candidates = len(candidates)
    print("{} new candidates considered".format(n_candidates))


# callbacks triggered after ``evaluate`` and ``generate_candidates``
mdl_prints = CallBacks(
    generate_candidates=_print_candidates_size,
    evaluate=_print_positive_gain,
)

mdl_prints.__doc__ = """
Base callback for miners which inherit the :class:`skmine.base.MDLOptimizer`

Prints data size and model size when compression has improved, only if ``verbose``
is set to True for the miner to attach.

Examples
--------
>>> from skmine.callbacks import mdl_prints
>>> from skmine.base import MDLOptimizer
>>> class MyMDLMiner(MDLOptimizer):
...     def __init__(self):
...         self.codetable_ = dict()
...     def generate_candidates(self):
...         return [(2,), (2, 3), (2, 4)]
...     def evaluate(self):
...         pass
>>> miner = MyMDLMiner()
>>> mdl_prints(miner)
>>> miner.generate_candidates()
3 new candidates considered
[(2,), (2, 3), (2, 4)]
"""