经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
聊聊GLM-4-9B开源模型的微调loss计算
来源:cnblogs  作者:又见阿郎  时间:2024/6/12 16:15:42  对本文有异议

概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

可了解其它loss计算的文章:
再聊多轮对话微调训练格式与长序列训练
聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析
聊聊大模型多轮对话的训练及优化

微调

微调格式:

  1. [
  2. {
  3. "messages": [
  4. {
  5. "role": "system",
  6. "content": "<system prompt text>",
  7. "tools": [
  8. {
  9. "name": "<tool name>",
  10. "args": {
  11. "<arg name>": "<arg value>"
  12. }
  13. }
  14. ]
  15. },
  16. {
  17. "role": "user",
  18. "content": "<user prompt text>"
  19. },
  20. {
  21. "role": "assistant",
  22. "content": "<assistant response text>"
  23. },
  24. {
  25. "role": "user",
  26. "content": "<user prompt text>"
  27. },
  28. {
  29. "role": "assistant",
  30. "content": "<assistant response text>"
  31. },
  32. {
  33. "role": "observation",
  34. "content": "<observation prompt text>"
  35. },
  36. {
  37. "role": "assistant",
  38. "content": "<assistant response observation>"
  39. },
  40. {
  41. "role": "user",
  42. "content": "<user prompt text>"
  43. },
  44. {
  45. "role": "assistant",
  46. "content": "<assistant response text>"
  47. }
  48. ]
  49. }
  50. ]

微调源码地址:finetune.py
Loss计算代码:

  1. def process_batch(
  2. batch: Mapping[str, Sequence],
  3. tokenizer: PreTrainedTokenizer,
  4. max_input_length: int,
  5. max_output_length: int,
  6. ) -> dict[str, list]:
  7. batched_conv = batch['messages']
  8. batched_input_ids = []
  9. batched_labels = []
  10. # batched_conv 是一个数组
  11. # conv 是数组内的单个 message
  12. for conv in batched_conv:
  13. input_ids = [151331, 151333]
  14. loss_masks = [False, False]
  15. # conv 是数组内的单个 message
  16. # message 是 单个role json对象
  17. for message in conv:
  18. message = process_message(message)
  19. # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
  20. loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
  21. # 获取 input 文本的数字表示(ids)
  22. new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
  23. # 计算整句的 mask
  24. new_loss_masks = [loss_mask_val] * len(new_input_ids)
  25. # 拼接message中的每段json
  26. input_ids += new_input_ids
  27. # 拼接message中每段json对应的mask
  28. loss_masks += new_loss_masks
  29. # 追加结尾的 token id
  30. input_ids.append(tokenizer.eos_token_id)
  31. loss_masks = [False, *loss_masks]
  32. labels = []
  33. for input_id, mask in zip(input_ids, loss_masks):
  34. if mask:
  35. # 添加到label,计算loss
  36. labels.append(input_id)
  37. else:
  38. # -100 不处理,即ignore_index
  39. labels.append(-100)
  40. max_length = max_input_length + max_output_length + 1
  41. # 截断
  42. batched_input_ids.append(input_ids[:max_length])
  43. batched_labels.append(labels[:max_length])
  44. return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

  1. tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
  2. data_manager = DataManager(data_dir, ft_config.data_config)
  3. # 数据集拆分遍历
  4. train_dataset = data_manager.get_dataset(
  5. Split.TRAIN,
  6. functools.partial(
  7. process_batch,
  8. tokenizer=tokenizer,
  9. max_input_length=ft_config.max_input_length,
  10. max_output_length=ft_config.max_output_length,
  11. ),
  12. batched=True,
  13. )
  14. print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md

原文链接:https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg

更多大模型相关的文章,请上个人公众号查阅:
image

原文链接:https://www.cnblogs.com/zhiyong-ITNote/p/18243420

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号