Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import random | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import BertTokenizer, BertModel | |
| from sklearn.model_selection import train_test_split | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.font_manager as fm | |
| import requests | |
| # ========================================== | |
| # 1. 基础环境配置 | |
| # ========================================== | |
| def set_seed(seed=42): | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| set_seed(42) | |
| device = torch.device("cpu") | |
| # --- 字体下载 --- | |
| font_filename = 'SimHei.ttf' | |
| if not os.path.exists(font_filename): | |
| print("正在下载中文字体...") | |
| try: | |
| url = "https://github.com/StellarCN/scp_zh/raw/master/fonts/SimHei.ttf" | |
| r = requests.get(url) | |
| with open(font_filename, "wb") as f: | |
| f.write(r.content) | |
| except Exception as e: | |
| print(f"字体下载警告: {e}") | |
| try: | |
| fm.fontManager.addfont(font_filename) | |
| plt.rcParams['font.sans-serif'] = ['SimHei'] | |
| plt.rcParams['axes.unicode_minus'] = False | |
| except: | |
| pass | |
| # ========================================== | |
| # 2. 核心知识库与数据 | |
| # ========================================== | |
| MODEL_NAME = 'hfl/chinese-roberta-wwm-ext' | |
| LABEL_MAP = { | |
| 0: "抢劫罪", 1: "盗窃罪", 2: "诈骗罪", 3: "抢夺罪", 4: "侵占罪", | |
| 5: "职务侵占罪", 6: "挪用资金罪", 7: "敲诈勒索罪", | |
| 8: "故意毁坏财物罪", 9: "其他/无罪" | |
| } | |
| LEGAL_KNOWLEDGE_BASE = { | |
| "抢劫罪": {"article": "《刑法》第263条:以暴力、胁迫或者其他方法抢劫公私财物的。", "range": "处三年以上十年以下有期徒刑;\n情节严重的,处十年以上有期徒刑、无期徒刑或者死刑,并处罚金或者没收财产。"}, | |
| "盗窃罪": {"article": "《刑法》第264条:盗窃公私财物,数额较大的,或者多次盗窃、入户盗窃、携带凶器盗窃、扒窃的。", "range": "处三年以下有期徒刑、拘役或者管制,并处或者单处罚金;\n数额巨大或者有其他严重情节的,处三年以上十年以下有期徒刑,并处罚金;\n数额特别巨大或者有其他特别严重情节的,处十年以上有期徒刑或者无期徒刑,并处罚金或者没收财产。"}, | |
| "诈骗罪": {"article": "《刑法》第266条:诈骗公私财物。", "range": "数额较大的,处三年以下有期徒刑、拘役或者管制,并处或者单处罚金;\n数额巨大或者有其他严重情节的,处三年以上十年以下有期徒刑,并处罚金;\n数额特别巨大或者有其他特别严重情节的,处十年以上有期徒刑或者无期徒刑,并处罚金或者没收财产。"}, | |
| "抢夺罪": {"article": "《刑法》第267条:抢夺公私财物,数额较大的,或者多次抢夺的。", "range": "处三年以下有期徒刑、拘役或者管制,并处或者单处罚金;\n数额巨大或者有其他严重情节的,处三年以上十年以下有期徒刑,并处罚金;\n数额特别巨大或者有其他特别严重情节的,处十年以上尤其图形或者无期徒刑,并处罚金或者没收财产。"}, | |
| "侵占罪": {"article": "《刑法》第270条:将代为保管的他人财物非法占为己有,数额较大,拒不退还的。将他人遗忘物或者埋藏物非法占为己有,数额较大,拒不交出的。", "range": "处二年以下有期徒刑、拘役或者罚金;\n数额巨大或者有其他严重情节的,的处二年以上五年以下有期徒刑,并处罚金。"}, | |
| "职务侵占罪": {"article": "《刑法》第271条:公司、企业或者其他单位的工作人员利用职务上的便利,将本单位财物非法占为己有。", "range": "数额较大的,处三年以下有期徒刑或者拘役,并处罚金;\n数额巨大的,处三年以上十年以下有期徒刑,并处罚金;\n数额特别巨大的,处十年以上有期徒刑或者无期徒刑,并处罚金。三年以下有期徒刑或者拘役。"}, | |
| "挪用资金罪": {"article": "《刑法》第272条:公司、企业或者其他单位的工作人员利用职务上的便利,挪用本单位资金归个人使用或者借贷给他人。", "range": "数额较大、超过三个月未还的,或者虽未超过三个月,但数额较大、进行营利活动的,或者进行非法活动的,处三年以下有期徒刑或者拘役;\n挪用本单位资金数额巨大的,处三年以上七年以下有期徒刑;\n数额特别巨大的,处七年以上有期徒刑。"}, | |
| "敲诈勒索罪": {"article": "《刑法》第274条:敲诈勒索公私财物,数额较大或者多次敲诈勒索的。", "range": "数额较大或者多次敲诈勒索的,处三年以下有期徒刑、拘役或者管制,并处或者单处罚金;\n数额巨大或者有其他严重情节的,处三年以上十年以下有期徒刑,并处罚金;\n数额特别巨大或者有其他特别严重情节的,处十年以上有期徒刑,并处罚金。"}, | |
| "故意毁坏财物罪": {"article": "《刑法》第275条:故意毁坏公私财物,数额较大或者有其他严重情节的。", "range": "处三年以下有期徒刑、拘役或者罚金;\n数额巨大或者有其他特别严重情节的,处三年以上七年以下有期徒刑。"}, | |
| "其他/无罪": {"article": "无", "range": "无。"} | |
| } | |
| def generate_data(): | |
| data = [] | |
| # 0. 抢劫罪 | |
| data.extend([ | |
| ("以暴力、胁迫或者其他方法抢劫公私财物。", 0),("蒙面持刀进入金店,威胁营业员交出柜台里的所有黄金首饰。", 0), | |
| ("持刀拦截路人逼迫交出钱包", 0), ("暴力殴打店员抢走收银机现金", 0), ("入室盗窃被发现后拿出匕首威胁主人", 0), | |
| ("把人绑在椅子上搜身拿走金项链", 0), ("出租车司机拿刀逼乘客转账", 0), ("在小巷子里用棍棒把人打晕抢走手机", 0), | |
| ("冒充警察拦车,强行铐住司机拿走财物", 0), ("用电击棍威胁受害人交出银行卡", 0), ("抢完东西为了毁灭罪证,窝藏赃物,抗拒抓捕,而拿刀把人捅伤", 0), | |
| ("下药把网友迷晕,拿走随身财物", 0), ("拿着仿真枪冲进银行喊抢劫", 0), ("勒住被害人脖子致其昏迷,取走财物", 0), | |
| ("一群人围住学生扇耳光,强行抢走生活费", 0), ("入户抢劫,虽然没伤人但是拿刀威胁了", 0), ("骑摩托车抢包,因为拉扯导致受害人摔成骨折", 0), | |
| ]) | |
| # 1. 盗窃罪 | |
| data.extend([ | |
| ("盗窃公私财物或者扒窃。", 1),("在医院排队挂号时,用镊子夹走前方病人的皮夹子,内有现金2000元。", 1), | |
| ("趁人熟睡从枕头下拿走手机", 1), ("公交车上用镊子夹走乘客钱包", 1), ("深夜撬开商店卷帘门搬走烟酒", 1), | |
| ("看见路边电动车没拔钥匙直接骑走", 1), ("在网吧趁人上厕所拿走桌上的手机", 1), ("潜入邻居家偷走放在抽屉里的现金", 1), | |
| ("去商场试衣服,把新衣服穿里面偷走", 1), ("偷接路边的电缆线去卖废品", 1), ("在医院挂号处排队时偷别人包里的钱", 1), | |
| ("趁快递员送货,偷走三轮车上的包裹", 1), ("利用黑客技术偷偷转走别人的支付宝余额", 1), ("趁同事不注意,偷配钥匙打开保险柜拿钱", 1), | |
| ("在火车站候车室拎走睡着旅客的行李箱", 1), ("进入工地盗窃钢材", 1), ("偷偷拔走别人菜地里的珍贵药材", 1), | |
| ]) | |
| # 2. 诈骗罪 | |
| data.extend([ | |
| ("诈骗公私财物。", 2),("通过社交软件假装成白富美,以父亲生病为由骗取多名受害人钱财共计10万元。", 2), | |
| ("冒充公检法打电话说涉嫌洗钱要求转账", 2), ("虚构中奖信息骗取手续费", 2), ("网恋杀猪盘,诱导受害人投资虚假平台", 2), | |
| ("假装是大老板,骗取货款后失联", 2), ("碰瓷,故意撞车骗取赔偿金", 2), ("用假的金元宝冒充文物低价卖给路人", 2), | |
| ("冒充孙子出车祸,骗老奶奶汇款救命钱", 2), ("建立虚假购物网站,收款后不发货", 2), ("谎称有关系能低价买房,骗取中介费", 2), | |
| ("伪造房产证把租来的房子卖给别人", 2), ("假冒名医推销成分不明的保健品", 2), ("在网上发布虚假兼职刷单广告骗取本金", 2), | |
| ("冒充公司领导QQ要求财务转账", 2), ("用假币购买小额商品骗取真币找零", 2), ("谎称能帮人办理社保,收取费用后消失", 2), | |
| ]) | |
| # 3. 抢夺罪 | |
| data.extend([ | |
| ("驾驶助力车,在巷子口猛地拽走一名独行妇女挂在脖子上的金项链后逃离。", 3),("抢夺公私财物。", 3), | |
| ("骑摩托车经过路人时飞车夺走名牌包", 3), ("趁人低头看手机,一把夺走手机就跑", 3), ("在金店假装试戴金项链,戴上后转身就跑", 3), | |
| ("在银行门口趁人数钱,一把抓走现金逃跑", 3), ("趁人打电话不注意,抢了手机就跑", 3), ("买手机时拿着样机冲出店门", 3), | |
| ("在公交站台抢夺候车人的项链", 3), ("假装问路,趁机抢走对方手里的相机", 3), ("趁老人反应慢,夺走手里的买菜钱", 3), | |
| ("坐在副驾驶,伸手抢夺路人的手提袋", 3), ("趁店员转身拿货,抓起柜台上的金戒指就跑", 3), ("两人配合,一人引开注意,一人夺走财物", 3), | |
| ("在ATM机旁,趁人吐钞时抢走现金", 3), ("抢夺他人手中的欠条并撕毁", 3), ("飞车抢夺耳环,未造成人员伤亡", 3), | |
| ]) | |
| # 4. 侵占罪 | |
| data.extend([ | |
| ("捡到他人掉落的IPHONE手机,失主拨打时拒不承认,并将手机变卖获利。", 4),("将代为保管的他人财物非法占为己有,拒不退还。", 4), | |
| ("将他人的遗忘物或者埋藏物非法占为己有,拒不交出。", 4),("借朋友的高档相机玩了之后拒不归还", 4), | |
| ("捡到路人掉的装有5万元的钱包,拒不归还", 4), ("帮室友代管行李箱,结果把东西占为己有", 4), ("在出租屋床底发现金条,据为己有拒不交出", 4), | |
| ("4S店员工把客户遗忘在车里的名表拿走拒不承认", 4), ("替公司代收货款后谎称丢了,实际私吞", 4), ("捡到别人忘在ATM机的卡,取走钱款", 4), | |
| ("受朋友委托保管名画,后来谎称画被偷了", 4), ("租来的汽车到期后拒不归还,并把车卖了", 4), ("发现自家地里埋藏的古董,私自藏匿拒不上交", 4), | |
| ("拾得他人手机,刷机后自己使用", 4), ("帮邻居照看宠物狗,结果把狗卖了据为己有", 4), ("误收了别人多转的钱,对方索要时拒不退还", 4), | |
| ("保管公司电脑,离职时带走拒不归还", 4), ("老刘向老张借来一个名贵手表,后不想还给老张而占为己有", 4), | |
| ]) | |
| # 5. 职务侵占罪 | |
| data.extend([ | |
| ("身为仓储主管的王某,利用职务之便,多次私自将仓库内的空调零部件运出销售。", 5), | |
| ("公司、企业、单位工作人员,利用职务上的便利,将本单位的财物非法占为己有。", 5), | |
| ("私企会计做假账将公司钱转入自己卡", 5), ("国有银行行长虚报装修开支侵吞公款", 5), ("快递员将客户包裹占为己有", 5), | |
| ("公务员利用职务便利骗取国家补贴", 5), ("超市收银员利用漏洞私吞营业款", 5), ("国企高管伙同他人私分国有资产", 5), | |
| ("私企销售经理私吞客户货款", 5), ("村支书虚报冒领扶贫款归个人所有", 5), ("民办学校财务把学费转到自己账户炒股亏光", 5), | |
| ("公司仓库管理员把库存铜线偷出去卖", 5), ("公立学校校长截留教育经费", 5), ("物业管理员截留业主的物业费", 5), | |
| ("税务人员截留税款不入账", 5), ("利用职务之便将公司废旧设备变卖私吞", 5), ("虚报出差费用骗取公司/单位报销", 5), | |
| ]) | |
| # 6. 挪用资金罪 | |
| data.extend([ | |
| ("某公司会计挪用公款50万元用于个人炒股,并在半年后仍未归还。", 6), | |
| ("公司、企业、单位工作人员,利用职务上的便利,挪用本单位资金归个人使用,超过三个月未还。", 6), | |
| ("公司、企业、单位工作人员,利用职务上的便利,挪用本单位资金借贷给他人,超过三个月未还。", 6), | |
| ("公司、企业、单位工作人员,利用职务上的便利,挪用本单位资金进行非法活动。", 6), | |
| ("私企出纳私自把公司钱拿去炒股,打算赚了还", 6), ("村支书挪用村集体公款进行营利活动", 6), ("民营公司经理挪用货款借给朋友周转", 6), | |
| ("国企会计挪用公款归个人使用", 6), ("挪用公司资金进行赌博,但后来还上了", 6), ("公务员挪用单位资金给亲戚做生意", 6), | |
| ("私企会计把公司钱转出去理财,盈利归自己,本金还公司", 6), ("挪用公款进行赌博", 6), ("擅自将公司资金借给其他单位使用", 6), | |
| ("公立学校会计挪用学费炒股", 6), ("挪用子公司资金填补自己公司的亏空", 6), ("挪用社保基金归个人使用", 6), | |
| ("私自把客户备用金拿去买彩票", 6), ("国企出纳擅自把公款借给私企老板", 6), ("利用职务挪用单位资金归个人使用超过三个月", 6), | |
| ]) | |
| # 7. 敲诈勒索罪 | |
| data.extend([ | |
| ("掌握了受害人的隐私照片,威胁受害人支付5万元封口费,否则发给其家人。", 7),("敲诈勒索公私财物。", 7), | |
| ("威胁发裸照勒索前女友钱财", 7), ("如果不给保护费就天天来砸店", 7), ("抓住官员把柄,写信勒索巨额封口费", 7), | |
| ("PS艳照寄给受害人勒索钱财", 7), ("黑社会威胁不给钱就打断腿", 7), ("扣押欠债人子女,勒索赎金", 7), | |
| ("利用网贷裸条威胁女学生还高额利息", 7), ("自称记者,以曝光负面新闻勒索企业", 7), ("威胁房东不给钱就举报违建", 7), | |
| ("在饭菜里放虫子,威胁餐厅不赔钱就曝光", 7), ("捡到车牌,留下电话勒索车主赎金", 7), ("通过木马病毒锁死公司电脑,勒索比特币", 7), | |
| ("威胁不给钱就去家里闹事", 7), ("冒充黑社会打电话勒索", 7), ("知情人以揭发犯罪事实为由勒索钱财", 7), | |
| ]) | |
| # 8. 故意毁坏财物罪 | |
| data.extend([ | |
| ("因与邻居发生争执,被告人深夜用砖头将邻居刚买的汽车前挡风玻璃全部砸碎。", 8),("故意毁坏公私财物。", 8), | |
| ("因琐事泄愤,半夜划花对方的汽车", 8), ("吵架后砸碎邻居家窗户玻璃", 8), ("把共享单车扔进河里发泄", 8), | |
| ("因为嫉妒,剪烂室友的名牌包", 8), ("向别人的汽车排气管灌水泥", 8), ("放火烧毁他人堆放在田里的秸秆", 8), | |
| ("推倒别人砌好的围墙", 8), ("恶意砸坏街道上的摄像头", 8), ("把前男友的电脑摔个粉碎", 8), | |
| ("投毒毒死邻居家值钱的斗牛犬", 8), ("把别人地里的西瓜苗全部拔掉", 8), ("向别人的鱼塘里投毒", 8), | |
| ("砸坏公司的打卡机", 8), ("踹坏小区的门禁系统", 8), ("割破别人的汽车轮胎", 8), | |
| ]) | |
| # 9. 其他/无罪 | |
| data.extend([ | |
| ("双方因为停车位发生口角,后经派出所调解达成谅解,未造成人员受伤。", 9), | |
| ("今天天气不错,去楼下买了瓶水", 9), ("在家看电视,吃了个苹果", 9), ("路上捡到一个钱包,原地等待失主并归还了", 9), | |
| ("和朋友去公园放风筝", 9), ("买东西结账后离开", 9), ("去超市买鸡蛋排队付款", 9), ("在图书馆安静地看书", 9), | |
| ("扶老奶奶过马路", 9), ("将捡到的钱交给警察", 9), ("张三在餐厅吃饭拒不结账", 9), | |
| ("借了朋友五千块钱,因为失业了暂时还不上", 9), ("房客拖欠房租不给,房东很生气", 9), ("因为货物质量问题,拒绝支付尾款", 9), | |
| ("两人因为排队插队问题吵了起来", 9), ("两人因琐事发生口角,继而互殴,致一方轻伤", 9), ("开车不小心撞到了路边的行人,导致其骨折", 9), | |
| ("小王非法将小李拘禁在房间,声称不还钱就别想走", 9), ("在家里喝醉了酒,自己摔倒把头磕破了", 9), ("无证驾驶摩托车在路上狂飙", 9), | |
| ("吸食毒品被警察抓获", 9), ("在网络上散布谣言,造成恶劣影响", 9), ("非法携带管制刀具进入地铁", 9),("你好", 9), ("今晚吃什么呢", 9), | |
| ]) | |
| # 10. 难案 | |
| data.extend([ | |
| ("在小巷偷钱包被发现后,为了挣脱抓捕,挥拳将失主打成轻微伤并逃跑。", 0), | |
| ("尾随独行女青年,趁其不备夺走金项链,因用力过猛导致受害人颈部皮开肉绽。", 3), | |
| ("潜入他人房内行窃,被主人发现后,为了窝藏赃物当场使用折叠刀威胁对方别乱动。", 0), | |
| ("下药把相亲对象迷晕,然后带走其随身佩戴的钻戒和手机。", 0), | |
| ("用力拉扯被害人手中的皮包,因为用力过猛导致被害人手臂红肿。", 3), | |
| ("趁人不备夺走手机,为了逃跑推了一下被害人,未造成伤害。", 3), | |
| ("去朋友家做客,趁朋友去厨房做饭,将桌上充电的iPad塞进自己包里带走。", 1), | |
| ("在图书馆看书,趁邻座同学去接水的几分钟,拿走了对方放在桌上的名牌耳机。", 1), | |
| ("商场保安在值班巡逻时,看到柜台里一个戒指没锁,于是溜进去拿走藏入宿舍。", 1), | |
| ("快递公司分拣员,在传送带旁将一个破损包装里的黄金转运珠揣入兜里。", 5), | |
| ("公司保洁员趁下班无人,偷偷拿走办公室桌上的笔记本电脑。", 1), | |
| ("快递员送货途中,把车厢没锁好的包裹拿走藏起来。", 5), | |
| ("在手机店假装买手机,拿到真机试机时,趁店员转身拿配件,用假手机模型替换了真机。", 1), | |
| ("在餐厅吃饭后,趁服务员不注意,没有结账就偷偷从后门溜走。", 9), | |
| ("谎称可以帮忙办理入学,收取家长5万元赞助费,随后将钱挥霍并更换手机号失联。", 2), | |
| ("由于公司经营不善欠下巨额货款无法偿还,老板关门躲避债务。", 9), | |
| ("骑摩托车飞速夺取路人的挎包,导致路人被带倒在地拖行数米受重伤。", 0), | |
| ("趁店员不注意,抓起柜台上的金条就跑,店员在后面追,由于跑得快没被抓到。", 3), | |
| ("张三将电脑放在出租房内委托室友李四看管,李四私自将电脑卖掉并称电脑丢了,拒不赔钱。", 4), | |
| ("在出租车后座捡到前一名乘客掉落的手机,司机发现后将其关机并据为己有,拒不承认。", 4), | |
| ("在公园长椅上发现一个被他人遗忘的公文包,由于四下无人直接拎走。", 1), | |
| ("公司出纳利用管理现金的便利,私自支取5万元用于给自己买车,没打算还。", 5), | |
| ("公司经理利用职务便利,借用公司20万周转,打算三个月后卖了股票再还回来。", 6), | |
| ("私企主管将客户支付给公司的5000元定金直接花掉,月底报账时谎称客户没给。", 5), | |
| ("物业管理员张某挪用业主缴纳的维修基金去赌博,输光后无法归还。", 6), | |
| ("普通职员向同事借了3万块钱,约定一年后还,结果一年后由于没钱一直拖着。", 9), | |
| ("掌握了老板偷税漏税的证据,要求老板给10万元,否则就去税务局举报。", 7), | |
| ("买到过期面包后,向超市索赔1000元,并威胁不赔钱就发到短视频平台曝光。", 9), | |
| ("冒充黑社会老大打电话给店主,说有人出钱要你一条腿,如果不打钱过来就放火烧店。", 7), | |
| ("PS了几张某官员的合成裸照寄给对方,要求对方转账50万封口费。", 7), | |
| ("为了卖废铁,将正在路口使用的交通红绿灯拆走砸碎。", 9), | |
| ("潜入果园,为了报复果园老板,将即将成熟的几百棵苹果树全部砍断。", 8), | |
| ("半夜潜入仇人家里,没偷东西,而是把对方的高档家具全部泼了油漆。", 8), | |
| ("邻居家的狗经常半夜乱叫,一气之下买了毒香肠把狗毒死了。", 8), ("张三在马路上捡到一个钱包,原地等了两个小时还给了失主。", 9), | |
| ("路怒症发作,在马路上强行别车导致对方车辆失控撞上护栏,但无人受伤。", 9),("借了朋友5万块钱做生意,结果亏本了还不上,朋友报警。", 9), | |
| ("甲在候车室以需要紧急联络为名,向赵某借得高档手机,边打电话边向候车室外移动,出门后拔腿就跑,已经有所警觉的赵某猛追未果。", 9), | |
| ]) | |
| augmented_data = data * 4 | |
| return augmented_data | |
| class CrimeDataset(Dataset): | |
| def __init__(self, data, tokenizer, max_len): | |
| self.data = data | |
| self.tokenizer = tokenizer | |
| self.max_len = max_len | |
| def __len__(self): return len(self.data) | |
| def __getitem__(self, index): | |
| text, label = self.data[index] | |
| encoding = self.tokenizer.encode_plus( | |
| text, add_special_tokens=True, max_length=self.max_len, | |
| padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': encoding['input_ids'].flatten(), | |
| 'attention_mask': encoding['attention_mask'].flatten(), | |
| 'labels': torch.tensor(label, dtype=torch.long) | |
| } | |
| class CrimeClassifier(nn.Module): | |
| def __init__(self, n_classes): | |
| super(CrimeClassifier, self).__init__() | |
| self.bert = BertModel.from_pretrained(MODEL_NAME) | |
| self.drop = nn.Dropout(p=0.3) | |
| self.fc = nn.Linear(self.bert.config.hidden_size, n_classes) | |
| def forward(self, input_ids, attention_mask): | |
| outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| pooled_output = outputs.pooler_output | |
| x = self.drop(pooled_output) | |
| x = self.fc(x) | |
| return x | |
| # ========================================== | |
| # 3. 训练流程 | |
| # ========================================== | |
| def train_model(): | |
| print("正在加载 BERT 模型和分词器...") | |
| tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) | |
| model = CrimeClassifier(len(LABEL_MAP)).to(device) | |
| raw_data = generate_data() | |
| random.shuffle(raw_data) | |
| dataset = CrimeDataset(raw_data, tokenizer, max_len=80) | |
| loader = DataLoader(dataset, batch_size=32, shuffle=True) | |
| optimizer = optim.AdamW(model.parameters(), lr=5e-5) | |
| criterion = nn.CrossEntropyLoss() | |
| EPOCHS = 3 | |
| print(f"开始训练 (共 {EPOCHS} 轮)...") | |
| model.train() | |
| loss_history = [] | |
| for epoch in range(EPOCHS): | |
| total_loss = 0 | |
| for i, batch in enumerate(loader): | |
| input_ids = batch['input_ids'].to(device) | |
| attention_mask = batch['attention_mask'].to(device) | |
| labels = batch['labels'].to(device) | |
| optimizer.zero_grad() | |
| outputs = model(input_ids, attention_mask) | |
| loss = criterion(outputs, labels) | |
| loss.backward() | |
| optimizer.step() | |
| total_loss += loss.item() | |
| if i % 5 == 0: | |
| print(f"Epoch {epoch+1}/{EPOCHS} | Step {i} | Loss: {loss.item():.4f}") | |
| avg_loss = total_loss / len(loader) | |
| loss_history.append(avg_loss) | |
| print(f"Epoch {epoch+1} 完成,平均 Loss: {avg_loss:.4f}") | |
| # 绘图 | |
| plt.figure(figsize=(8, 4)) | |
| plt.plot(loss_history, marker='o') | |
| plt.title("Training Loss Trend") | |
| plt.xlabel("Epoch") | |
| plt.ylabel("Loss") | |
| plt.grid(True) | |
| plt.savefig("loss.png") | |
| plt.close() | |
| return model, tokenizer | |
| model, tokenizer = train_model() | |
| # ========================================== | |
| # 4. Gradio 界面 | |
| # ========================================== | |
| def ai_judge_interface(text): | |
| if not text.strip(): return "请输入案情", "", "" | |
| model.eval() | |
| with torch.no_grad(): | |
| encoded = tokenizer.encode_plus( | |
| text, max_length=80, padding='max_length', truncation=True, return_tensors='pt' | |
| ) | |
| input_ids = encoded['input_ids'].to(device) | |
| mask = encoded['attention_mask'].to(device) | |
| logits = model(input_ids, mask) | |
| probs = F.softmax(logits, dim=1) | |
| # --- 关键修改:获取 Top 2 --- | |
| top_probs, top_indices = torch.topk(probs, k=2, dim=1) | |
| # 第一名 | |
| first_idx = top_indices[0][0].item() | |
| first_conf = top_probs[0][0].item() | |
| first_label = LABEL_MAP[first_idx] | |
| # 第二名 | |
| second_idx = top_indices[0][1].item() | |
| second_conf = top_probs[0][1].item() | |
| second_label = LABEL_MAP[second_idx] | |
| # 结果格式化:显示前两名 | |
| verdict = f" 首选研判:【{first_label}】\n 置信度:{first_conf:.2%}\n\n" | |
| verdict += f" 备选可能性:{second_label}\n 置信度:{second_conf:.2%}" | |
| # 获取法条(仅针对首选研判) | |
| info = LEGAL_KNOWLEDGE_BASE.get(first_label, {}) | |
| law_content = f"{info.get('article', '暂无')} " | |
| # 给出量刑建议 | |
| sentencing = info.get('range', '建议进一步咨询律师') | |
| return verdict, law_content, sentencing | |
| # 界面构建 | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 刑事案件罪名辅助研判系统") | |
| gr.Markdown("深度学习课程期末设计") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_text = gr.Textbox( | |
| label="案情描述输入", | |
| lines=5, | |
| placeholder="例如:趁人不备,一把夺走金项链就跑..." | |
| ) | |
| run_btn = gr.Button("开始研判", variant="primary") | |
| gr.Markdown("### 快速测试用例") | |
| gr.Examples( | |
| examples=[ | |
| ["趁人不备,一把夺走金项链就跑。"], | |
| ["拿着刀逼迫路人交出钱包,不然就捅人。"], | |
| ["我是公司会计,虚报账目把50万公款转到自己卡里买房。"], | |
| ["借了朋友相机不还,说丢了,其实卖了。"], | |
| ["看邻居不顺眼,把他车胎扎破了。"], | |
| ["两人发生口角,互殴导致轻伤。"] | |
| ], | |
| inputs=input_text | |
| ) | |
| with gr.Column(scale=1): | |
| out_verdict = gr.Textbox(label="1. 罪名研判", lines=4) | |
| out_law = gr.Textbox(label="2. 法条依据", lines=2) | |
| out_sent = gr.Textbox(label="3. 法定刑参考", lines=2) | |
| with gr.Accordion("查看训练 Loss 曲线", open=False): | |
| gr.Image("loss.png", label="Training Loss") | |
| run_btn.click( | |
| ai_judge_interface, | |
| inputs=input_text, | |
| outputs=[out_verdict, out_law, out_sent] | |
| ) | |
| demo.launch() |