# Source code for lib.sedna.core.lifelong_learning.lifelong_learning

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from sedna.backend import set_backend
from sedna.core.base import JobBase
from sedna.common.file_ops import FileOps
from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
from sedna.common.constant import KBResourceConstant
from sedna.common.config import Context
from sedna.common.class_factory import ClassType, ClassFactory
from sedna.algorithms.multi_task_learning import MulTaskLearning
from sedna.service.client import KBClient


class LifelongLearning(JobBase):
    """
    Lifelong Learning (LL) is an advanced machine learning (ML) paradigm that
    learns continuously, accumulates the knowledge learned in the past, and
    uses/adapts it to help future learning and problem solving.

    Sedna provides the related interfaces for application development.

    Parameters
    ----------
    estimator : Instance
        An instance with the high-level API that greatly simplifies
        machine learning programming. Estimators encapsulate training,
        evaluation, prediction, and exporting for your model.
    task_definition : Dict
        Divide multiple tasks based on data,
        see `task_jobs.task_definition` for more detail.
        Defaults to ``{"method": "TaskDefinitionByDataAttr"}`` when falsy.
    task_relationship_discovery : Dict
        Discover relationships between all tasks, see
        `task_jobs.task_relationship_discovery` for more detail.
    task_mining : Dict
        Mining tasks of inference sample,
        see `task_jobs.task_mining` for more detail.
    task_remodeling : Dict
        Remodeling tasks based on their relationships,
        see `task_jobs.task_remodeling` for more detail.
    inference_integrate : Dict
        Integrate the inference results of all related tasks,
        see `task_jobs.inference_integrate` for more detail.
    unseen_task_detect : Dict
        Unseen-task-detection algorithm (with parameters) that has been
        registered to ClassFactory, see `sedna.algorithms.unseen_task_detect`
        for more detail. Defaults to ``{"method": "TaskAttrFilter"}`` when
        falsy.

    Examples
    --------
    >>> estimator = XGBClassifier(objective="binary:logistic")
    >>> task_definition = {
    ...     "method": "TaskDefinitionByDataAttr",
    ...     "param": {"attribute": ["season", "city"]}
    ... }
    >>> task_relationship_discovery = {
    ...     "method": "DefaultTaskRelationDiscover", "param": {}
    ... }
    >>> task_mining = {
    ...     "method": "TaskMiningByDataAttr",
    ...     "param": {"attribute": ["season", "city"]}
    ... }
    >>> task_remodeling = None
    >>> inference_integrate = {
    ...     "method": "DefaultInferenceIntegrate", "param": {}
    ... }
    >>> unseen_task_detect = {
    ...     "method": "TaskAttrFilter", "param": {}
    ... }
    >>> ll_jobs = LifelongLearning(
    ...     estimator=estimator,
    ...     task_definition=task_definition,
    ...     task_relationship_discovery=task_relationship_discovery,
    ...     task_mining=task_mining,
    ...     task_remodeling=task_remodeling,
    ...     inference_integrate=inference_integrate,
    ...     unseen_task_detect=unseen_task_detect
    ... )
    """

    def __init__(self, estimator, task_definition=None,
                 task_relationship_discovery=None,
                 task_mining=None,
                 task_remodeling=None,
                 inference_integrate=None,
                 unseen_task_detect=None):

        if not task_definition:
            task_definition = {
                "method": "TaskDefinitionByDataAttr"
            }
        if not unseen_task_detect:
            unseen_task_detect = {
                "method": "TaskAttrFilter"
            }
        # The multi-task learner wraps the user estimator and is what the
        # base job class sees as its `estimator`.
        e = MulTaskLearning(
            estimator=estimator,
            task_definition=task_definition,
            task_relationship_discovery=task_relationship_discovery,
            task_mining=task_mining,
            task_remodeling=task_remodeling,
            inference_integrate=inference_integrate)
        # Only the method name is kept here; it is resolved through
        # ClassFactory (ClassType.UTD) at inference time.
        self.unseen_task_detect = unseen_task_detect.get(
            "method", "TaskAttrFilter")
        self.unseen_task_detect_param = e._parse_param(
            unseen_task_detect.get("param", {})
        )
        config = dict(
            ll_kb_server=Context.get_parameters("KB_SERVER"),
            output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
        )
        task_index = FileOps.join_path(
            config['output_url'], KBResourceConstant.KB_INDEX_NAME.value)
        config['task_index'] = task_index
        super(LifelongLearning, self).__init__(
            estimator=e, config=config
        )
        self.job_kind = K8sResourceKind.LIFELONG_JOB.value
        self.kb_server = KBClient(kbserver=self.config.ll_kb_server)

    @staticmethod
    def _get_callback(post_process):
        """
        Resolve a post-process hook into a callable.

        Parameters
        ----------
        post_process : callable, str or None
            A callable is returned unchanged; any other non-None value is
            looked up in ClassFactory under ``ClassType.CALLBACK``.

        Returns
        -------
        callable or None
            ``None`` when `post_process` is ``None``.
        """
        if post_process is None:
            return None
        if callable(post_process):
            return post_process
        return ClassFactory.get_cls(ClassType.CALLBACK, post_process)

    def train(self, train_data,
              valid_data=None,
              post_process=None,
              action="initial",
              **kwargs):
        """
        Fit the estimator on `train_data` and publish the resulting knowledge
        (task models, task samples, task extractor and KB index).

        Parameters
        ----------
        train_data : BaseDataSource
            Train data, see `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            Valid data, BaseDataSource or None.
        post_process : function or str
            A callable, or the name of a method registered in ClassFactory as
            ``ClassType.CALLBACK``; it is called as
            ``post_process(estimator, train_result)`` after training.
        action : str
            `update` or `initial` the knowledge base. NOTE: this method does
            not read `action` yet; distinguishing an incremental update from
            a full overwrite is still a TODO.
        kwargs : Dict
            parameters for `estimator` training, Like:
            `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        train_history : object
            The estimator's training result, or the callback's return value
            when `post_process` is given.
        """
        callback_func = self._get_callback(post_process)

        res, task_index_url = self.estimator.train(
            train_data=train_data,
            valid_data=valid_data,
            **kwargs
        )

        # todo: Distinguishing incremental update and fully overwrite
        # `task_index_url` is either a path to a serialized index or the
        # index object itself.
        if isinstance(task_index_url, str) and FileOps.exists(task_index_url):
            task_index = FileOps.load(task_index_url)
        else:
            task_index = task_index_url

        extractor = task_index['extractor']
        task_groups = task_index['task_groups']

        # Several task groups may point at the same model file; copy each
        # distinct file to the output location only once.
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(
                self.config.output_url,
                os.path.basename(model_file)
            )
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_server.upload_file(save_model)
            except Exception as err:
                # KB upload failed: keep a locally loaded model instead.
                self.log.error(
                    f"Upload task model of {model_file} fail: {err}"
                )
                model = set_backend(
                    estimator=self.estimator.estimator.base_model
                )
                model.load(model_file)
            task.model.model = model

            for _task in task.tasks:
                sample_dir = FileOps.join_path(
                    self.config.output_url,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                # Save this task's own samples: both the file name and the
                # data_url recorded below belong to `_task`, not the group.
                _task.samples.save(sample_dir)
                try:
                    sample_dir = self.kb_server.upload_file(sample_dir)
                except Exception as err:
                    self.log.error(
                        f"Upload task samples of {_task.entry} fail: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.config.output_url,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value
        )
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_server.upload_file(extractor)
        except Exception as err:
            self.log.error(f"Upload task extractor fail: {err}")

        task_info = {
            "task_groups": task_groups,
            "extractor": extractor
        }
        fd, name = tempfile.mkstemp()
        # mkstemp hands back an open OS-level descriptor; only the path is
        # used below, so release the descriptor to avoid leaking one per call.
        os.close(fd)
        FileOps.dump(task_info, name)

        index_file = self.kb_server.update_db(name)
        if not index_file:
            self.log.error("KB update Fail !")
            index_file = name
        FileOps.upload(index_file, self.config.task_index)

        task_info_res = self.estimator.model_info(
            self.config.task_index,
            relpath=self.config.data_path_prefix)
        self.report_task_info(
            None, K8sResourceKindStatus.COMPLETED.value, task_info_res)
        self.log.info(f"Lifelong learning Train task Finished, "
                      f"KB index save in {self.config.task_index}")
        return callback_func(self.estimator, res) if callback_func else res

    def update(self, train_data, valid_data=None, post_process=None,
               **kwargs):
        """
        Update the knowledge base with new training data.

        Delegates to :meth:`train` with ``action="update"``; see that method
        for the meaning of the parameters and the return value.
        """
        return self.train(
            train_data=train_data,
            valid_data=valid_data,
            post_process=post_process,
            action="update",
            **kwargs
        )

    def evaluate(self, data, post_process=None, **kwargs):
        """
        Evaluate the performance of each task produced by training, and
        filter out tasks based on the configured rule.

        A task is dropped (its status is set to 0 in the knowledge base) when
        at least one of its scores satisfies
        ``score <operator> model_threshold``. ``operator`` (default ``">"``)
        and ``model_threshold`` (default ``0.1``) are read through
        ``self.get_parameters``; an unsupported operator falls back to
        ``"<"``.

        Parameters
        ----------
        data : BaseDataSource
            valid data, see `sedna.datasources.BaseDataSource` for more detail.
        post_process : function or str
            A callable, or the name of a method registered in ClassFactory as
            ``ClassType.CALLBACK``; it is called as ``post_process(res)``.
        kwargs: Dict
            parameters for `estimator` evaluate, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        The estimator's evaluation result, or the callback's return value
        when `post_process` is given.
        """
        callback_func = self._get_callback(post_process)

        task_index_url = self.get_parameters(
            "MODEL_URLS", self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        self.log.info(
            f"Download kb index from {task_index_url} to {index_url}")
        FileOps.download(task_index_url, index_url)
        res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
        drop_tasks = []

        model_filter_operator = self.get_parameters("operator", ">")
        model_threshold = float(self.get_parameters('model_threshold', 0.1))

        operator_map = {
            ">": lambda x, y: x > y,
            "<": lambda x, y: x < y,
            "=": lambda x, y: x == y,
            ">=": lambda x, y: x >= y,
            "<=": lambda x, y: x <= y,
        }
        if model_filter_operator not in operator_map:
            self.log.warning(
                f"operator {model_filter_operator} use to "
                f"compare is not allow, set to <"
            )
            model_filter_operator = "<"
        operator_func = operator_map[model_filter_operator]

        for detail in tasks_detail:
            scores = detail.scores
            entry = detail.entry
            self.log.info(f"{entry} scores: {scores}")
            # NOTE(review): the condition drops a task when ANY score matches,
            # while the warning text says "all scores"; confirm the intent.
            if any(operator_func(float(x), model_threshold)
                   for x in scores.values()):
                self.log.warning(
                    f"{entry} will not be deploy because all "
                    f"scores {model_filter_operator} {model_threshold}")
                drop_tasks.append(entry)

        drop_task = ",".join(drop_tasks)
        index_file = self.kb_server.update_task_status(drop_task, new_status=0)
        if not index_file:
            # KB update failed: fall back to the locally downloaded index.
            self.log.error("KB update Fail !")
            index_file = str(index_url)
        self.log.info(
            f"upload kb index from {index_file} to {self.config.task_index}")
        FileOps.upload(index_file, self.config.task_index)
        task_info_res = self.estimator.model_info(
            self.config.task_index, result=res,
            relpath=self.config.data_path_prefix)
        self.report_task_info(
            None, K8sResourceKindStatus.COMPLETED.value, task_info_res,
            kind="eval")
        return callback_func(res) if callback_func else res

    def inference(self, data=None, post_process=None, **kwargs):
        """
        Predict the result for input data based on the trained knowledge.

        Parameters
        ----------
        data : BaseDataSource
            inference sample, see `sedna.datasources.BaseDataSource` for
            more detail.
        post_process: function
            function or a registered method, passed through to the
            estimator's `predict` (e.g. a label transform).
        kwargs: Dict
            parameters for `estimator` predict, Like:
            `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        result : array_like
            results array, contain all inference results in each sample.
        is_unseen_task : bool
            `true` means detect an unseen task, `false` means not
        tasks : List
            tasks assigned to each sample.
        """
        task_index_url = self.get_parameters(
            "MODEL_URLS", self.config.task_index)
        index_url = self.estimator.estimator.task_index_url
        FileOps.download(task_index_url, index_url)
        res, tasks = self.estimator.predict(
            data=data, post_process=post_process, **kwargs
        )

        is_unseen_task = False
        if self.unseen_task_detect:
            try:
                if callable(self.unseen_task_detect):
                    unseen_task_detect_algorithm = self.unseen_task_detect()
                else:
                    unseen_task_detect_algorithm = ClassFactory.get_cls(
                        ClassType.UTD, self.unseen_task_detect
                    )()
            except ValueError as err:
                # Detection is best-effort: log and report "not unseen".
                self.log.error("Lifelong learning "
                               "Inference [UTD] : {}".format(err))
            else:
                is_unseen_task = unseen_task_detect_algorithm(
                    tasks=tasks, result=res, **self.unseen_task_detect_param
                )
        return res, is_unseen_task, tasks