protoc-gen-validate

2023-08-10 ⏳1.6分钟(0.6千字)

本文诣在让读者了解protoc-gen-validate的前世今生,以及如何开发一个protoc-gen-xxx的插件

在刚入职我司的时候,做过一个技术项目:为部门开源框架实现一个轻量级的proto校验规则,Pull request在此。当时的方式是直接合入了开源框架中,现在回想起来还有些许遗憾:

正好借此机会温故知新,回忆下当时的心路历程,补全遗憾。

背景

当时了解过一些开源插件,比如:

syntax = "proto3";

package examplepb;

import "validate/validate.proto";

message Person {
  uint64 id = 1 [(validate.rules).uint64.gt = 999];
}
syntax = "proto3";

package examplepb;

import "github.com/gogo/protobuf/gogoproto/gogo.proto";

message Person {
  int64 uid = 2 [(gogoproto.moretags) = "validate:\"gt=0,required\""];
}

他们都有一些相同的劣势:

基于此我们期望自己实现一个protoc-genvalidate插件,目标为:

过程

前置知识

我们需要先了解proto插件是如何运作的。

当你执行proptc --go_out=. test.proto时,它会将proto文件的结构发给protoc-gen-go可执行文件。本质上其实是会根据--xxx_out去寻找protoc-gen-xxx。同时我们可以配合使用google.golang.org/protobuf/compiler/protogengoogle自带的proto解析包去获取我们整个proto的结构树。

基于此,我们可以先生成一个简单的demo:protoc-gen-validate插件打印整个proto结构树。

插件代码如下:

package main

import (
        "fmt"
        "google.golang.org/protobuf/compiler/protogen"
)

func main() {
        protogen.Options{}.Run(func(plugin *protogen.Plugin) error {
                fmt.Println(plugin)
                return nil
        })
}

生成可执行文件:

go install .

生成测试proto文件:

syntax = "proto3";

package main;

service Test{
    rpc Foo(FooReq) returns (FooResp);
}

message FooReq {
    // @gt:0
    int64 uid = 1;
}

message FooResp {
}

执行命令protoc test.proto --validate_out=.

  proto protoc test.proto --validate_out=.
--validate_out: protoc-gen-validate: Plugin output is unparseable: &{file_to_generate:\"test.proto\"  proto_file:{name:\"test.proto\"  package:\"main\"  message_type:{name:\"FooReq\"  field:{name:\"uid\"  number:1  label:LABEL_OPTIONAL
ype:TYPE_INT64  json_name:\"uid\"}}  message_type:{name:\"FooResp\"}  service:{name:\"Test\"  method:{name:\"Foo\"  input_type:\".main.FooReq\"  output_type:\".main.FooResp\"}}  options:{go_package:\"./\"}  source_code_info:{location:{sp
n:0  span:0  span:16  span:1}  location:{path:12  span:0  span:0  span:18}  location:{path:2  span:2  span:0  span:13}  location:{path:8  span:4  span:0  span:25}  location:{path:8  path:11  span:4  span:0  span:25}  location:{path:6  pa
h:0  span:6  span:0  span:8  span:1}  location:{path:6  path:0  path:1  span:6  span:8  span:12}  location:{path:6  path:0  path:2  path:0  span:7  span:4  span:38}  location:{path:6  path:0  path:2  path:0  path:1  span:7  span:8  span:
1}  location:{path:6  path:0  path:2  path:0  path:2  span:7  span:12  span:18}  location:{path:6  path:0  path:2  path:0  path:3  span:7  span:29  span:36}  location:{path:4  path:0  span:10  span:0  span:13  span:1}  location:{path:4
ath:0  path:1  span:10  span:8  span:14}  location:{path:4  path:0  path:2  path:0  span:12  span:4  span:18  leading_comments:\" @gt:0\\n\"}  location:{path:4  path:0  path:2  path:0  path:5  span:12  span:4  span:9}  location:{path:4
ath:0  path:2  path:0  path:1  span:12  span:10  span:13}  location:{path:4  path:0  path:2  path:0  path:3  span:12  span:16  span:17}  location:{path:4  path:1  span:15  span:0  span:16  span:1}  location:{path:4  path:1  path:1  span:
5  span:8  span:15}}  syntax:\"proto3\"}  compiler_version:{major:3  minor:17  patch:3  suffix:\"\"} [0x1400016e900] map[test.proto:0x1400016e900] 0 0x1400000cf48 map[] map[main.FooReq:0x140000010e0 main.FooResp:0x14000001200] false 0  [
 {<nil> <nil>} <nil>}\n

我们初步demo成功。并可以发现注释在File.Services.Comments.Leading字段上。

实现细节

大致里程碑分为下述几点:

初步拟定需要的功能:

  1. 数字类型
  1. 字符串
  1. 数组
  1. 支持嵌套/循环的规则

确认协议:

使用 @+key+:+value的方式规避正常注射

生成方法: 通过协议我们可以获取到规则的数组kv形式,再通过text/template包生成模版内容,最后输出到文件中,拿eq举例:

package rule

const eqTpl = `
        if {{ .Key }} != {{ .Value }} {
                return {{ .Field.Parent.GoIdent.GoName }}ValidationError {
                        field:  "{{ .Field.GoName }}",
                        reason: "value must equal {{ escape .Value }}",
                }
        }
`

更多的内容可以直接查看project

此步完成后,我们已经可以为每个message生成专有的validate方法:

func (m *NumericsReq) validate() error {
  if m == nil {
          return nil
  }  
  if m.GetA() != 1.23 {
          return NumericsReqValidationError{
                  field:  "A",
                  reason: "value must equal 1.23",
          }
  }  
  if m.GetB() >= 20 {
          return NumericsReqValidationError{
                  field:  "B",
                  reason: "value must less than 20",
          }
  }  
  if m.GetB() <= 10 {
          return NumericsReqValidationError{
                  field:  "B",
                  reason: "value must greater than 10",
          }
  }  
  if m.GetC() > 20 {
          return NumericsReqValidationError{
                  field:  "C",
                  reason: "value must less than or equal to 20",
          }
  }  
  if m.GetC() < 10 {
          return NumericsReqValidationError{
                  field:  "C",
                  reason: "value must greater than or equal to 10",
          }
  }  
  var NumericsReq_D_In = map[uint32]struct{}{  
          1: {},  
          2: {},  
          3: {},
  }  
  if _, ok := NumericsReq_D_In[m.GetD()]; !ok {
          return NumericsReqValidationError{
                  field:  "D",
                  reason: "value must be in list [1,2,3]",
          }
  }  
  var NumericsReq_E_NotIn = map[float32]struct{}{  
          1: {},  
          2: {},  
          3: {},
  }  
  if _, ok := NumericsReq_E_NotIn[m.GetE()]; ok {
          return NumericsReqValidationError{
                  field:  "E",
                  reason: "value must be not in list [1,2,3]",
          }
  }  
  if m.GetF() <= 1 || m.GetF() >= 5 {
          return NumericsReqValidationError{
                  field:  "F",
                  reason: "value must in range (1,5)",
          }
  }  
  if m.GetG() < 1 || m.GetG() > 5 {
          return NumericsReqValidationError{
                  field:  "G",
                  reason: "value must in range [1,5]",
          }
  }  
  return nil
}

我们项目使用的是sniper框架,使用protoc-gen-twirp生成路由代码,我们增加函数:

func (t *twirp) addValidate(method *protogen.Method, service *protogen.Service) {
	if t.ValidateEnable {
		t.P(`  if  validerr := reqContent.validate(); validerr != nil {`)
		t.P(`    s.writeError(ctx, resp, twirp.InvalidArgumentError("argument", validerr.Error()))`)
		t.P(`    return`)
		t.P(`  }`)
	}
}

为每个req增加validate校验。基于此我们完成了整个protoc-gen-validate的功能实现

完善测试用例以及使用文档

补充了一波测试用例以及使用文档.让项目看上去明显更沉稳了哈哈。

结果

protoc-gen-validate也顺利推到了github上,欢迎小伙伴们使用以及交流,一起共勉。