diff --git a/internal/database/connect.go b/internal/database/connect.go
index 4a325a1..88e81ff 100644
--- a/internal/database/connect.go
+++ b/internal/database/connect.go
@@ -38,5 +38,5 @@ func ChangeConnection(c *gin.Context) {
session.Set("current", name)
session.Save()
- c.String(200, templates.ConnectionsList(connections, name)+TableTree(c))
+ c.String(200, templates.ConnectionsList(connections, name)+TableTree(c)+EnumTree(c))
}
diff --git a/internal/database/tree.go b/internal/database/tree.go
index 46bfe78..0cfca94 100644
--- a/internal/database/tree.go
+++ b/internal/database/tree.go
@@ -29,7 +29,7 @@ func TableTree(c *gin.Context) string {
url := connections[current]
- tree, err := generateTree(url)
+ tree, err := generateTableTree(url)
if err != nil {
fmt.Println(err)
return ""
@@ -39,7 +39,7 @@ func TableTree(c *gin.Context) string {
}
// Generate the tree of the database tables
-func generateTree(url string) (map[string][]model.Column, error) {
+func generateTableTree(url string) (map[string][]model.Column, error) {
conn, err := sql.Open("postgres", url)
if err != nil {
return map[string][]model.Column{}, err
@@ -118,17 +118,23 @@ func fillColumns(conn *sql.DB, tree map[string][]model.Column) error {
fkeys = append(fkeys, fkey)
}
- rows, err := conn.Query(fmt.Sprintf("SELECT column_name, is_nullable, data_type, character_maximum_length FROM information_schema.columns WHERE table_name = '%s';", table))
+ rows, err := conn.Query(fmt.Sprintf("SELECT c.column_name, c.is_nullable, c.data_type, c.character_maximum_length, t.typname AS enum_type FROM information_schema.columns c JOIN pg_type t ON c.udt_name = t.typname WHERE c.table_name = '%s';", table))
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
- var column model.Column
- if err := rows.Scan(&column.Name, &column.Nullable, &column.Type, &column.MaxLength); err != nil {
+ var (
+ column model.Column
+ enumType string
+ )
+ if err := rows.Scan(&column.Name, &column.Nullable, &column.Type, &column.MaxLength, &enumType); err != nil {
return err
}
+ if column.Type == "USER-DEFINED" {
+ column.Type = enumType
+ }
if column.Name == pkey {
column.PrimaryKey = true
}
@@ -172,3 +178,69 @@ func getUniqueColumns(conn *sql.DB, table string) ([]string, error) {
return cols, nil
}
+
+// Generate the tree of the database enums and their values
+func EnumTree(c *gin.Context) string {
+ session := sessions.Default(c)
+ connections_bytes, ok := session.Get("connections").([]byte)
+ current, ok := session.Get("current").(string)
+ if !ok {
+ fmt.Println("No connections found")
+ return ""
+ }
+
+ var connections map[string]string
+ if err := json.Unmarshal(connections_bytes, &connections); err != nil {
+ fmt.Println(err)
+ return ""
+ }
+
+ url := connections[current]
+
+ enums, err := genereteEnumTree(url)
+ if err != nil {
+ fmt.Println(err)
+ return ""
+ }
+
+ return templates.EnumTree(enums)
+}
+
+// Generate the tree of the database enums and their values from a
+// provided connection URL.
+func genereteEnumTree(url string) (map[string][]string, error) {
+ conn, err := sql.Open("postgres", url)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ enums, err := enumList(conn)
+ if err != nil {
+ return nil, err
+ }
+
+ return enums, nil
+}
+
+// Get a list/map of all the enums in the database.
+// The key is the name of the enum and the value is a slice of the enum values.
+func enumList(conn *sql.DB) (map[string][]string, error) {
+ rows, err := conn.Query("SELECT t.typname AS enum_name, e.enumlabel AS enum_value FROM pg_type t JOIN pg_enum e ON t.oid = e.enumtypid JOIN pg_namespace n ON n.oid = t.typnamespace WHERE t.typcategory = 'E' AND n.nspname NOT IN ('pg_catalog', 'information_schema') ORDER BY t.typname, e.enumsortorder;")
+ if err != nil {
+ return map[string][]string{}, err
+ }
+ defer rows.Close()
+
+ enums := make(map[string][]string)
+ for rows.Next() {
+ var enum, value string
+ if err := rows.Scan(&enum, &value); err != nil {
+ return map[string][]string{}, err
+ }
+
+ enums[enum] = append(enums[enum], value)
+ }
+
+ return enums, nil
+}
diff --git a/internal/http/router.go b/internal/http/router.go
index 35076ca..1b2a767 100644
--- a/internal/http/router.go
+++ b/internal/http/router.go
@@ -70,9 +70,15 @@ func populate(web, api *gin.RouterGroup) {
})
api.POST("/connections/connect", database.ChangeConnection)
- web.GET("/connections/tree", func(c *gin.Context) {
+ web.GET("/connections/tree/table", func(c *gin.Context) {
c.String(200, database.TableTree(c))
})
+ web.GET("/connections/tree/enum", func(c *gin.Context) {
+ c.String(200, database.EnumTree(c))
+ })
+ web.GET("/connections/tree", func(c *gin.Context) {
+ c.String(200, database.TableTree(c)+database.EnumTree(c))
+ })
web.GET("/query/auto", templates.ToggleQueryType)
diff --git a/internal/templates/tree.go b/internal/templates/tree.go
index a71d4b3..41b788b 100644
--- a/internal/templates/tree.go
+++ b/internal/templates/tree.go
@@ -7,10 +7,10 @@ import (
"github.com/Azpect3120/Web-Database-Viewer/internal/model"
)
-// Tree definition
-const TREE_OPEN string = `
`
-const TREE_CLOSE string = `
`
-const TREE_BODY_TEMPLATE string = `%s`
+// Table tree definition
+const TABLE_TREE_OPEN string = ``
+const TABLE_TREE_CLOSE string = `
`
+const TABLE_TREE_BODY_TEMPLATE string = `%s`
// Table definition
const TABLE_TEMPLATE string = `
@@ -26,9 +26,9 @@ const TABLE_TEMPLATE string = `
`
// Fields definition
-const FIELDS_LIST_OPEN string = ``
-const FIELDS_LIST_CLOSE string = `
`
-const FIELD_TEMPLATE string = `
+const TABLE_FIELDS_LIST_OPEN string = ``
+const TABLE_FIELDS_LIST_CLOSE string = `
`
+const TABLE_FIELD_TEMPLATE string = `