package utils
import (
"database/sql"
)
// SQLResult is the in-memory form of a query result, as produced by
// UnmarshalSQLRows.
type SQLResult struct {
	// Count is the row counter maintained by UnmarshalSQLRows; after a
	// successful call it equals len(Rows).
	Count int
	// Columns holds the result's column names as returned by
	// (*sql.Rows).Columns.
	Columns []string
	// Rows holds one entry per decoded row. UnmarshalSQLRows stores each
	// entry as a map[string]interface{} keyed by column name.
	Rows []interface{}
}
// UnmarshalSQLRows reads every remaining row of rows into a SQLResult and
// always closes rows before returning.
//
// Each row is appended to Rows as a map[string]interface{} keyed by column
// name. Values are decoded according to the database type name reported by
// the driver (compared exactly, case-sensitively):
//
//   - "BOOL" columns become bool
//   - "INT4" columns become int64
//   - every other type (including "VARCHAR", "TEXT", "UUID", "TIMESTAMP" and
//     other integer types such as "INT8") becomes a string, converted by
//     database/sql
//
// SQL NULL cannot be distinguished in the output: it decodes to the zero value
// of the target type (false, 0 or ""). If several columns share a name, the
// right-most one wins in the row map.
//
// If scanning a row fails, or iteration stops because of an error reported by
// rows.Err (driver, network or context cancellation), the returned result
// holds the rows decoded before the failure and the error is returned
// unwrapped. Count always equals len(Rows).
func UnmarshalSQLRows(rows *sql.Rows) (*SQLResult, error) {
	// Best-effort close so early returns release the rows. The close error is
	// deliberately ignored; iteration failures are reported via rows.Err below.
	defer func() {
		_ = rows.Close()
	}()

	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}

	result := &SQLResult{}
	result.Columns, err = rows.Columns()
	if err != nil {
		return nil, err
	}

	// Scan destinations depend only on the column types, so build them once
	// and reuse them for every row. This is safe: each sql.Null* Scan resets
	// both its value and Valid (including for NULL), and the decoded values
	// are copied into the row map before the next Scan.
	scanArgs := make([]interface{}, len(columnTypes))
	for i, ct := range columnTypes {
		switch ct.DatabaseTypeName() {
		case "BOOL":
			scanArgs[i] = new(sql.NullBool)
		case "INT4":
			scanArgs[i] = new(sql.NullInt64)
		default:
			// "VARCHAR", "TEXT", "UUID", "TIMESTAMP" and any unrecognized type.
			scanArgs[i] = new(sql.NullString)
		}
	}

	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			return result, err
		}
		row := make(map[string]interface{}, len(columnTypes))
		for i, ct := range columnTypes {
			switch v := scanArgs[i].(type) {
			case *sql.NullBool:
				row[ct.Name()] = v.Bool
			case *sql.NullInt64:
				row[ct.Name()] = v.Int64
			case *sql.NullString:
				row[ct.Name()] = v.String
			}
		}
		result.Rows = append(result.Rows, row)
		// Count only rows that were actually decoded, so it matches len(Rows)
		// even when a later Scan fails.
		result.Count++
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; only rows.Err distinguishes the two. Without this
	// check a truncated result would be reported as a success.
	if err := rows.Err(); err != nil {
		return result, err
	}
	return result, nil
}