Go中反射的利用,通用sql封装,字段拼接对应

在后端接口开发中,往往需要针对某一张表写相对应的增删改查的sql方法,比如我们查询某张表的数据

func (r *DomainDao) GetHostsByModel(model *config.HostsModel, startNum int, pagesize int) ([]*config.HostsModel, error) {

        db := r.Db.GetDB()
        if db == nil {
                return nil, errors.New("db is nil")
        }
        result := make([]*config.HostsModel, 0)
        args := make([]interface{}, 0)
        strSql := "SELECT * FROM table WHERE 1=1 "
        if model != nil {

                if model.Gid != 0 {
                        strSql += " AND g
                        args = append(args, model.Gid)
                }
                if len(strings.Trim(model.CdnDomain, " ")) != 0 {
                        strSql += " AND cdn_domain=? "
                        args = append(args, model.CdnDomain)
                }
                if len(strings.Trim(model.CdnName, " ")) != 0 {
                        strSql += " AND cdn_name=? "
                        args = append(args, model.CdnName)
                }
                if model.CdnType != 0 {
                        strSql += " AND cdn_type=? "
                        args = append(args, model.CdnType)
                }
                if model.Master != 0 {
                        strSql += " AND master=? "
                        args = append(args, model.Master)
                }
                if model.Mode != 0 {
                        strSql += " AND mode=? "
                        args = append(args, model.Mode)
                }
                if model.AuthOutTime != 0 {
                        strSql += " AND auth_out_time=? "
                        args = append(args, model.AuthOutTime)
                }
                if model.Enable != 0 {
                        strSql += " AND enable=? "
                        args = append(args, model.Enable)
                }
                if len(strings.Trim(model.AuthKey, " ")) != 0 {
                        strSql += " AND auth_key=? "
                        args = append(args, model.AuthKey)
                }
                if len(strings.Trim(model.Reserve, " ")) != 0 {
                        strSql += " AND reserve=? "
                        args = append(args, model.Reserve)
                }
        }
        if startNum >= 0 && pagesize > 0 {
                strSql += " LIMIT ?,?  "
                args = append(args, startNum)
                args = append(args, pagesize)
        }

        var err error
        if len(args) > 0 {
                _, err = db.ExecSelect(strSql, &result, args...)
        } else {
                _, err = db.ExecSelect(strSql, &result)
        }
        return result, err
}

  出于封装的目的,我们常常将查询条件动态拼接,以满足不同的查询业务。但每张表的字段都不相同,如果每张表都像这样写就显得冗余低效了,所以需要对其再做一层封装。

通过反射进行封装:

// GetSelectTableSql builds a parameterised "SELECT * FROM tableName" statement
// whose WHERE clause is derived from the populated fields of model.
//
// Parameters:
//   - model: a struct, or a pointer to one, whose fields carry a gorm tag with
//     a "column:<name>" segment. nil (or a nil pointer) means "no conditions".
//     Any other non-struct value panics inside reflect, as it did before.
//   - tableName: concatenated into the SQL text as-is, so it must come from
//     trusted code, never from user input.
//   - startNum, pagesize: when startNum >= 0 and pagesize > 0 a "LIMIT ?,?"
//     clause is appended and both values are bound as arguments.
//
// It returns the SQL text and the (never nil) list of bound arguments.
func (db *MySQL) GetSelectTableSql(model interface{}, tableName string, startNum int, pagesize int) (string, []interface{}) {
	where, args := selectConditions(model)
	strSql := " SELECT * FROM " + tableName + " WHERE 1=1 " + where
	if startNum >= 0 && pagesize > 0 {
		strSql += " LIMIT ?,?  "
		args = append(args, startNum, pagesize)
	}
	return strSql, args
}

// GetSelectTableCountSql builds the "SELECT COUNT(1)" counterpart of
// GetSelectTableSql: same conditions derived from model, no LIMIT clause.
//
// model and tableName follow the same rules as in GetSelectTableSql. It
// returns the SQL text and the (never nil) list of bound arguments.
func (db *MySQL) GetSelectTableCountSql(model interface{}, tableName string) (string, []interface{}) {
	where, args := selectConditions(model)
	return " SELECT COUNT(1) FROM " + tableName + " WHERE 1=1 " + where, args
}

// selectConditions turns the populated fields of model into " AND col = ? "
// fragments plus their bound arguments.
//
// A field contributes a condition when it is exported, has a gorm column name,
// and is a non-empty string, a non-zero signed integer, or a time.Time that
// tool.CheckIsDefaultTime does not report as the default. A "delete_at" column
// holding the default time contributes " AND delete_at IS NULL " instead.
func selectConditions(model interface{}) (string, []interface{}) {
	args := make([]interface{}, 0)
	if model == nil {
		return "", args
	}

	// Accept pointers as well as values; a nil pointer carries no conditions.
	refValue := reflect.ValueOf(model)
	for refValue.Kind() == reflect.Ptr {
		if refValue.IsNil() {
			return "", args
		}
		refValue = refValue.Elem()
	}
	refType := refValue.Type()

	var where strings.Builder
	for i := 0; i < refValue.NumField(); i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported: reflect refuses Interface() on these.
			continue
		}
		column := selectColumnName(fieldType.Tag.Get("gorm"))
		if column == "" {
			// No column mapping: emitting " AND  = ? " would be invalid SQL.
			continue
		}

		fieldValue := refValue.Field(i)
		isadd := false
		switch v := fieldValue.Interface().(type) {
		case string:
			isadd = len(v) > 0
		case int, int8, int16, int32, int64:
			isadd = fieldValue.Int() != 0
		case time.Time:
			if !tool.CheckIsDefaultTime(v) {
				isadd = true
			} else if column == "delete_at" {
				where.WriteString(" AND delete_at IS NULL ")
			}
		}
		if isadd {
			where.WriteString(" AND " + column + " = ? ")
			args = append(args, fieldValue.Interface())
		}
	}
	return where.String(), args
}

// selectColumnName extracts <name> from the "column:<name>" segment of a gorm
// struct tag, wherever that segment appears. It returns "" when there is none.
func selectColumnName(gormTag string) string {
	for _, part := range strings.Split(gormTag, ";") {
		part = strings.TrimSpace(part)
		if strings.HasPrefix(part, "column:") {
			return strings.TrimSpace(strings.TrimPrefix(part, "column:"))
		}
	}
	return ""
}

// InsertTable inserts one row into tableName, using the populated fields of
// model as the column list.
//
// Parameters:
//   - model: a struct, or a pointer to one, whose fields carry a gorm tag with
//     a "column:<name>" segment. Only exported fields with a column name are
//     considered, and only when they hold a non-empty string, a non-zero signed
//     integer, or a time.Time that tool.CheckIsDefaultTime does not report as
//     the default.
//   - tableName: concatenated into the SQL text as-is, so it must come from
//     trusted code, never from user input.
//
// It returns whatever db.Exec returns for the statement (presumably the new
// row's id; confirm against Exec). It returns -1 and an error when model is
// nil, is not a struct, or has no populated column.
func (db *MySQL) InsertTable(model interface{}, tableName string) (int64, error) {
	if model == nil {
		return -1, errors.New("model is nil")
	}

	// Accept pointers as well as values; reflect.Value.NumField panics on a
	// pointer, and a typed nil pointer is not caught by the check above.
	refValue := reflect.ValueOf(model)
	for refValue.Kind() == reflect.Ptr {
		if refValue.IsNil() {
			return -1, errors.New("model is nil")
		}
		refValue = refValue.Elem()
	}
	if refValue.Kind() != reflect.Struct {
		return -1, errors.New("model is not a struct")
	}
	refType := refValue.Type()

	fieldCount := refValue.NumField()
	columns := make([]string, 0, fieldCount)
	args := make([]interface{}, 0, fieldCount)
	for i := 0; i < fieldCount; i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported: reflect refuses Interface() on these.
			continue
		}

		// Column name comes from the "column:<name>" segment of the gorm tag,
		// wherever it appears in the tag.
		column := ""
		for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(part, "column:") {
				column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
				break
			}
		}
		if column == "" {
			// No column mapping: an empty name would corrupt the column list.
			continue
		}

		fieldValue := refValue.Field(i)
		isadd := false
		switch v := fieldValue.Interface().(type) {
		case string:
			isadd = len(v) > 0
		case int, int8, int16, int32, int64:
			isadd = fieldValue.Int() != 0
		case time.Time:
			isadd = !tool.CheckIsDefaultTime(v)
		}
		if isadd {
			columns = append(columns, column)
			args = append(args, fieldValue.Interface())
		}
	}

	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	// One "?" placeholder per bound argument.
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
	insertSql := "insert " + tableName + " (" + strings.Join(columns, ",") + ") " +
		" values (" + placeholders + ")"
	return db.Exec(insertSql, args...)
}

// UpdateTableByColumn updates rows of tableName, setting every populated field
// of model and selecting rows by the equality conditions in mapcolumn.
//
// Parameters:
//   - model: a struct, or a pointer to one, whose fields carry a gorm tag with
//     a "column:<name>" segment. Only exported fields with a column name are
//     considered. A field is written when it holds a non-empty string, a
//     non-zero signed integer, or a time.Time that tool.CheckIsDefaultTime does
//     not report as the default. A "delete_at" column holding the default time
//     is written as NULL.
//   - tableName: concatenated into the SQL text as-is, so it must come from
//     trusted code, never from user input.
//   - mapcolumn: column -> value pairs ANDed into the WHERE clause. Keys are
//     concatenated into the SQL text and must be trusted. When nil or empty,
//     the row is selected by the model's "gid" column.
//
// It returns whatever db.Exec returns for the statement. It returns -1 and an
// error when model is nil or not a struct, when no field is populated, or when
// there is neither a mapcolumn entry nor a non-zero gid to select by.
func (db *MySQL) UpdateTableByColumn(model interface{}, tableName string, mapcolumn map[string]interface{}) (int64, error) {
	if model == nil {
		return -1, errors.New("model is nil")
	}

	// Accept pointers as well as values; reflect.Value.NumField panics on a
	// pointer, and a typed nil pointer is not caught by the check above.
	refValue := reflect.ValueOf(model)
	for refValue.Kind() == reflect.Ptr {
		if refValue.IsNil() {
			return -1, errors.New("model is nil")
		}
		refValue = refValue.Elem()
	}
	if refValue.Kind() != reflect.Struct {
		return -1, errors.New("model is not a struct")
	}
	refType := refValue.Type()

	fieldCount := refValue.NumField()
	assignments := make([]string, 0, fieldCount)
	args := make([]interface{}, 0, fieldCount)
	var gid int64

	for i := 0; i < fieldCount; i++ {
		fieldType := refType.Field(i)
		if fieldType.PkgPath != "" {
			// Unexported: reflect refuses Interface() on these.
			continue
		}

		// Column name comes from the "column:<name>" segment of the gorm tag,
		// wherever it appears in the tag.
		column := ""
		for _, part := range strings.Split(fieldType.Tag.Get("gorm"), ";") {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(part, "column:") {
				column = strings.TrimSpace(strings.TrimPrefix(part, "column:"))
				break
			}
		}
		if column == "" {
			// No column mapping: an empty name would corrupt the SET clause.
			continue
		}

		fieldValue := refValue.Field(i)
		if column == "gid" {
			// Read by kind so any signed integer width works; a hard
			// .(int64) assertion panics for int or int32 fields.
			switch fieldValue.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
				gid = fieldValue.Int()
			}
		}

		isadd := false
		switch v := fieldValue.Interface().(type) {
		case string:
			isadd = len(v) > 0
		case int, int8, int16, int32, int64:
			isadd = fieldValue.Int() != 0
		case time.Time:
			if !tool.CheckIsDefaultTime(v) {
				isadd = true
			} else if column == "delete_at" {
				assignments = append(assignments, "delete_at=NULL")
			}
		}
		if isadd {
			assignments = append(assignments, column+"=?")
			args = append(args, fieldValue.Interface())
		}
	}

	if len(args) < 1 {
		return -1, errors.New("args is nil")
	}

	// Default to selecting by gid. An empty map is treated like nil: running
	// with no condition would leave "where 1=1" and update every row.
	if len(mapcolumn) == 0 {
		if gid == 0 {
			return -1, errors.New("update where is nil")
		}
		mapcolumn = map[string]interface{}{"gid": gid}
	}

	updateStr := "update " + tableName + " SET " + strings.Join(assignments, ",") + " where 1=1"
	// Condition text and its argument are appended in the same iteration, so
	// map ordering cannot mismatch placeholders and values.
	for k, v := range mapcolumn {
		updateStr += " AND " + k + "=? "
		args = append(args, v)
	}

	return db.Exec(updateStr, args...)
}

  

调用:

// InsertRatetemplate inserts model into the table named by
// ratetemplateTableName via the generic InsertTable helper.
//
// It returns InsertTable's result unchanged. It returns -1 with dao.DbErrMsg
// when no database connection is available, and -1 with daoErrMsg when model
// is nil.
func (r *RatetemplateDao) InsertRatetemplate(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}

	// model is known to be non-nil here, so it is dereferenced directly:
	// InsertTable reflects over the struct value, not a pointer to it.
	return db.InsertTable(*model, ratetemplateTableName)
}

// UpdateRatetemplateById updates the row of the table named by
// ratetemplateTableName that matches model's gid, via the generic
// UpdateTableByColumn helper (a nil mapcolumn selects by gid).
//
// It returns UpdateTableByColumn's result unchanged. It returns -1 with
// dao.DbErrMsg when no database connection is available, and -1 with
// daoErrMsg when model is nil.
func (r *RatetemplateDao) UpdateRatetemplateById(model *config.RatetemplateModel) (int64, error) {
	db := r.Db.GetDB()
	if db == nil {
		return -1, errors.New(dao.DbErrMsg)
	}
	if model == nil {
		return -1, errors.New(daoErrMsg)
	}

	// model is known to be non-nil here, so it is dereferenced directly:
	// UpdateTableByColumn reflects over the struct value, not a pointer to it.
	return db.UpdateTableByColumn(*model, ratetemplateTableName, nil)
}

// GetRatetemplateByModel queries the table named by ratetemplateTableName,
// filtering on the populated fields of model and paginating with startNum and
// pagesize as interpreted by GetSelectTableSql.
//
// A nil model means "no filter". The returned slice is never nil when a
// connection is available; the error is whatever ExecSelect reports, or
// dao.DbErrMsg when there is no connection.
func (r *RatetemplateDao) GetRatetemplateByModel(model *config.RatetemplateModel, startNum int, pagesize int) ([]*config.RatetemplateModel, error) {
	conn := r.Db.GetDB()
	if conn == nil {
		return nil, errors.New(dao.DbErrMsg)
	}

	// Hand the builder a struct value, or a true nil interface when there is
	// no filter; a typed nil pointer would not compare equal to nil there.
	var cond interface{}
	if model != nil {
		cond = *model
	}
	query, params := conn.GetSelectTableSql(cond, ratetemplateTableName, startNum, pagesize)

	rows := make([]*config.RatetemplateModel, 0)
	var err error
	if len(params) == 0 {
		_, err = conn.ExecSelect(query, &rows)
	} else {
		_, err = conn.ExecSelect(query, &rows, params...)
	}
	return rows, err
}

  传入的参数对象结构体需要在tag里面定义相应的解析值column:

// RatetemplateModel is a row/condition struct for the generic SQL helpers.
// Each field's gorm tag carries a "column:<name>" segment giving the table
// column the field maps to.
type RatetemplateModel struct {
        // Gid maps to column "gid"; tagged as the auto-increment primary key
        // (tag comment: unique identifier).
        Gid          int64     `sql:"Gid" gorm:"column:gid;primary_key;auto_increment;comment:'唯一标识';type:bigint(20)" json:"gid"`
        // CdnGid maps to column "cdn_gid" (tag comment: gid of the pull-stream domain).
        CdnGid       int64     `sql:"CdnGid" gorm:"column:cdn_gid;not null;comment:'拉流域名gid';type:bigint(20)" json:"cdn_gid"`
        // AppName maps to column "app_name" (tag comment: business line name, e.g. "live").
        AppName      string    `sql:"AppName" gorm:"column:app_name;not null;comment:'业务线名称(live)';type:varchar(32)" json:"app_name"`
}

  这样我们就不需要再写每张表的常规增删改查的sql语句了,而且当有大量的单一业务时,可以写一个代码生成工具根据数据库来生成这些代码

ps:反射时传入的对象不能是指针类型(对指针的 reflect.Value 调用 NumField 会 panic)。如果外层业务传入的是指针类型,需要先转换为值类型,关键代码:
intoModel = *model。以上内容只是为了方便记录和理解反射逻辑,某些具体的sql执行方法未贴出。