"""Character-level word segmentation driven by a Viterbi decoder.

Every character of a Han-only run is tagged with one of four states
('B', 'M', 'E', 'S'); ``__cut`` turns the tag sequence into words and ``cut``
applies that to the Han runs of an arbitrary sentence.
"""
import re
import os
import marshal
import sys

# Score used when an emission or transition is missing from the model tables.
# Scores are combined by addition, so this acts as a "practically impossible"
# penalty (presumably the tables hold log-probabilities -- confirm against the
# model files).
MIN_FLOAT = -3.14e100

PROB_START_P = "prob_start.p"
PROB_TRANS_P = "prob_trans.p"
PROB_EMIT_P = "prob_emit.p"

# For each state, the states that are considered as its predecessor in viterbi().
PrevStatus = {
    'B': ('E', 'S'),
    'M': ('M', 'B'),
    'S': ('S', 'E'),
    'E': ('B', 'M'),
}

# Runs of characters in the range U+4E00..U+9FA5; only these are sent to the
# Viterbi decoder.
re_han = re.compile("([\u4E00-\u9FA5]+)")
# Decimal numbers and ASCII alphanumeric runs inside non-Han text.
# Raw string: "\d" and "\." are not valid Python string escapes.
re_skip = re.compile(r"(\d+\.\d+|[a-zA-Z0-9]+)")


def load_model():
    """Load the three model tables from marshal files next to this module.

    Returns:
        A ``(start_p, trans_p, emit_p)`` tuple, read in that order from
        ``PROB_START_P``, ``PROB_TRANS_P`` and ``PROB_EMIT_P``.

    Raises:
        OSError: if one of the files cannot be opened.
    """
    _curpath = os.path.normpath(
        os.path.join(os.getcwd(), os.path.dirname(__file__)))

    def _load(filename):
        # NOTE(review): marshal is not safe against malicious data; these files
        # are expected to ship with the package -- never point this at
        # untrusted input.
        with open(os.path.join(_curpath, filename), 'rb') as f:
            return marshal.load(f)

    return _load(PROB_START_P), _load(PROB_TRANS_P), _load(PROB_EMIT_P)


# On Jython the tables are read from the marshal files; elsewhere they are
# imported from the sibling modules of this package.
if sys.platform.startswith("java"):
    start_P, trans_P, emit_P = load_model()
else:
    from . import prob_start, prob_trans, prob_emit
    start_P, trans_P, emit_P = prob_start.P, prob_trans.P, prob_emit.P


def viterbi(obs, states, start_p, trans_p, emit_p):
    """Return the best-scoring state sequence for ``obs``.

    Args:
        obs: non-empty indexable sequence of observations (characters).
        states: iterable of state names; must cover the states referenced by
            ``PrevStatus`` and include 'E' and 'S'.
        start_p: mapping state -> initial score.
        trans_p: mapping state -> {next state -> transition score}.
        emit_p: mapping state -> {observation -> emission score}.

    Returns:
        ``(score, path)`` where ``path`` is a list of states, one per
        observation. The final state is always 'E' or 'S'.

    Raises:
        IndexError: if ``obs`` is empty.
    """
    # V[t][y] is the best score of any path that ends in state y at time t.
    V = [{y: start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT) for y in states}]
    # back[t][y] is the predecessor of y on that best path. Storing pointers
    # instead of whole paths keeps the decoder linear in len(obs).
    back = [{}]

    for t in range(1, len(obs)):
        prev_row = V[-1]
        row = {}
        pointers = {}
        for y in states:
            em_p = emit_p[y].get(obs[t], MIN_FLOAT)
            # Tuples compare by score first, then by state name, which fixes
            # the tie-breaking order.
            prob, state = max(
                (prev_row[y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0)
                for y0 in PrevStatus[y]
            )
            row[y] = prob
            pointers[y] = state
        V.append(row)
        back.append(pointers)

    # Only 'E' and 'S' are accepted as the last state.
    prob, state = max((V[-1][y], y) for y in ('E', 'S'))

    # Walk the back-pointers from the last step to the first.
    path = [state]
    for pointers in reversed(back[1:]):
        state = pointers[state]
        path.append(state)
    path.reverse()
    return (prob, path)


def __cut(sentence):
    """Yield the words of ``sentence`` according to its decoded state tags.

    'B' marks where a word starts, 'E' closes the word opened by the last 'B',
    and 'S' emits the character on its own. ``sentence`` must be non-empty.
    """
    _, pos_list = viterbi(sentence, ('B', 'M', 'E', 'S'),
                          start_P, trans_P, emit_P)
    begin, next_start = 0, 0
    for i, char in enumerate(sentence):
        pos = pos_list[i]
        if pos == 'B':
            begin = i
        elif pos == 'E':
            yield sentence[begin:i + 1]
            next_start = i + 1
        elif pos == 'S':
            yield char
            next_start = i + 1
    # Safety net: emit whatever follows the last completed word.
    if next_start < len(sentence):
        yield sentence[next_start:]


def cut(sentence):
    """Yield the segments of ``sentence`` in order.

    Args:
        sentence: ``str``, or a bytes-like object with ``decode``; bytes are
            decoded as UTF-8, falling back to GBK with undecodable bytes
            dropped.

    Yields:
        Words from each Han run (via the Viterbi decoder). Non-Han text is
        split around numbers / ASCII alphanumeric runs and every non-empty
        piece is yielded as-is.
    """
    if not isinstance(sentence, str):
        try:
            sentence = sentence.decode('utf-8')
        except UnicodeDecodeError:
            sentence = sentence.decode('gbk', 'ignore')

    # The capturing group makes split() keep the Han runs in the result.
    for blk in re_han.split(sentence):
        if re_han.match(blk):
            for word in __cut(blk):
                yield word
        else:
            for x in re_skip.split(blk):
                if x:
                    yield x