Spaces:
Sleeping
Sleeping
| import logging | |
| from collections import defaultdict | |
| from transformers import pipeline, AutoTokenizer | |
logger = logging.getLogger(__name__)


class _ModelConfigMap(dict):
    """Model-alias -> load-config mapping with a pass-through fallback.

    Registered aliases map to explicit model/tokenizer locations. Any other key
    is treated as a model id/path used for both the model and its tokenizer.
    The fallback entry is built on demand and NOT stored, so looking up
    arbitrary names does not grow the mapping.
    """

    def __missing__(self, model):
        # dict.__getitem__ passes the missing key here. The previous
        # defaultdict(lambda model: ...) could never work: defaultdict calls
        # its default_factory with no arguments, so every unregistered name
        # raised TypeError.
        return {
            "model_name_or_path": model,
            "tokenizer_name": model,
        }


class CweInferHelper:
    """Predict CWE labels for a patch with a transformers text-classification pipeline.

    Call ``load_model`` once (or whenever switching models), then ``infer``
    for each patch.
    """

    # Number of highest-scoring labels returned by infer().
    TOP_K = 5
    # Token budget for the classifier input; infer() truncates to MAX_LENGTH - 10 tokens.
    MAX_LENGTH = 1024
    MODEL_CONFIG = _ModelConfigMap({
        "patchouli-cwe-UniXcoder": {
            "model_name_or_path": "./backend/model/cwe-cls/patchouli-unixcoder",
            "tokenizer_name": "microsoft/unixcoder-base-nine",
        },
    })
    # Explicitly registered aliases (other names still load via the fallback).
    PREDEF_MODEL = list(MODEL_CONFIG.keys())

    def __init__(self):
        self.model = None       # alias of the currently loaded model, or None
        self.classifier = None  # transformers pipeline, set by load_model()
        self.tokenizer = None   # the pipeline's tokenizer, set by load_model()

    def load_model(self, model: str) -> None:
        """Load the classifier registered under *model* (no-op if already loaded).

        Args:
            model: A key of ``MODEL_CONFIG``. Any other string is used directly
                as the model and tokenizer id/path.

        Raises:
            Whatever ``transformers.pipeline`` raises when the model cannot be
            loaded. In that case the previously loaded state is left unchanged.
        """
        if model == self.model and self.classifier is not None:
            return
        logger.info("Loading CWE classify model: %s", model)
        config = self.MODEL_CONFIG[model]
        classifier = pipeline(
            "text-classification",
            model=config["model_name_or_path"],
            tokenizer=config["tokenizer_name"],
            device_map="auto",
        )
        # Commit state only after a successful load, so a failed attempt is
        # not mistaken for "already loaded" on the next call.
        self.classifier = classifier
        # Reuse the pipeline's tokenizer rather than loading the same one twice;
        # this also guarantees truncation uses exactly the pipeline's vocabulary.
        self.tokenizer = classifier.tokenizer
        self.model = model

    def infer(self, diff_code, patch_message=None) -> dict:
        """Classify a patch and return its top CWE labels.

        Args:
            diff_code: The patch/diff text.
            patch_message: Optional commit message; included when non-empty.

        Returns:
            Mapping of label -> score for the ``TOP_K`` highest-scoring labels.

        Raises:
            ValueError: If ``load_model`` has not been called successfully.
        """
        if self.classifier is None:
            raise ValueError("Model is not loaded")
        input_text = ""
        if patch_message:
            input_text += f"[MESSAGE]\n{patch_message}\n"
        input_text += f"[PATCH]\n{diff_code}"
        logger.info("Classifying CWE for diff code")
        # Truncate on token boundaries, then turn the kept tokens back into text
        # for the pipeline. No padding and skip_special_tokens=True: otherwise the
        # decoded string would contain literal <s>/</s>/<pad> markers that the
        # pipeline re-tokenizes as real (attended) tokens. The -10 presumably
        # leaves headroom for the special tokens the pipeline adds and small
        # decode/re-encode length drift -- confirm against the model's limit.
        input_ids = self.tokenizer(
            input_text, max_length=self.MAX_LENGTH - 10, truncation=True
        ).input_ids
        input_text = self.tokenizer.decode(input_ids, skip_special_tokens=True)
        predictions = self.classifier(input_text, top_k=self.TOP_K)
        return {item["label"]: item["score"] for item in predictions}
# Shared module-level helper instance (presumably imported by other modules -- confirm).
cwe_infer_helper = CweInferHelper()

if __name__ == "__main__":
    # Smoke test: classify a sample kernel patch adding bounds checks to the
    # ipvs sockopt handlers.
    code = """diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c
index 6bde12da2fe003..c37ac2d7bec44d 100644
--- a/net/netfilter/ipvs/ip_vs_ctl.c
+++ b/net/netfilter/ipvs/ip_vs_ctl.c
@@ -2077,6 +2077,10 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user *user, unsigned int len)
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX)
+ return -EINVAL;
+ if (len < 0 || len > MAX_ARG_LEN)
+ return -EINVAL;
if (len != set_arglen[SET_CMDID(cmd)]) {
pr_err("set_ctl: len %u != %u\\n",
len, set_arglen[SET_CMDID(cmd)]);
@@ -2352,17 +2356,25 @@ do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len)
{
unsigned char arg[128];
int ret = 0;
+ unsigned int copylen;
if (!capable(CAP_NET_ADMIN))
return -EPERM;
+ if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX)
+ return -EINVAL;
+
if (*len < get_arglen[GET_CMDID(cmd)]) {
pr_err("get_ctl: len %u < %u\\n",
*len, get_arglen[GET_CMDID(cmd)]);
return -EINVAL;
}
- if (copy_from_user(arg, user, get_arglen[GET_CMDID(cmd)]) != 0)
+ copylen = get_arglen[GET_CMDID(cmd)];
+ if (copylen > 128)
+ return -EINVAL;
+
+ if (copy_from_user(arg, user, copylen) != 0)
return -EFAULT;
if (mutex_lock_interruptible(&__ip_vs_mutex))
"""
    # Use the registered alias. "patchouli" is not a MODEL_CONFIG key and would
    # otherwise be treated as a model id/path.
    cwe_infer_helper.load_model("patchouli-cwe-UniXcoder")
    result = cwe_infer_helper.infer(code)
    print(result)