diff --git a/internal/database/tree.go b/internal/database/tree.go index d0de7ae..46bfe78 100644 --- a/internal/database/tree.go +++ b/internal/database/tree.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/Azpect3120/Web-Database-Viewer/internal/model" "github.com/Azpect3120/Web-Database-Viewer/internal/templates" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -38,20 +39,20 @@ func TableTree(c *gin.Context) string { } // Generate the tree of the database tables -func generateTree(url string) (map[string][]string, error) { +func generateTree(url string) (map[string][]model.Column, error) { conn, err := sql.Open("postgres", url) if err != nil { - return map[string][]string{}, err + return map[string][]model.Column{}, err } defer conn.Close() tree, err := tableList(conn) if err != nil { - return map[string][]string{}, err + return map[string][]model.Column{}, err } if err := fillColumns(conn, tree); err != nil { - return map[string][]string{}, err + return map[string][]model.Column{}, err } return tree, nil @@ -59,20 +60,20 @@ func generateTree(url string) (map[string][]string, error) { // Return a map with the keys being the table names and the values // being blank which can be later used to store the columns. -func tableList(conn *sql.DB) (map[string][]string, error) { +func tableList(conn *sql.DB) (map[string][]model.Column, error) { rows, err := conn.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE';") if err != nil { - return map[string][]string{}, err + return map[string][]model.Column{}, err } defer rows.Close() - tree := make(map[string][]string) + tree := make(map[string][]model.Column) for rows.Next() { var table string if err := rows.Scan(&table); err != nil { - return map[string][]string{}, err + return map[string][]model.Column{}, err } - tree[table] = []string{} + tree[table] = []model.Column{} } return tree, nil @@ -84,22 +85,90 @@ func tableList(conn *sql.DB) (map[string][]string, error) { // For now, the only data stored is the // column name, but in the future this could be expanded to store // datatype, constraints, primary keys, relationship, etc. -func fillColumns(conn *sql.DB, tree map[string][]string) error { +func fillColumns(conn *sql.DB, tree map[string][]model.Column) error { + var pkey string + var fkeys []model.ForeignKey for table := range tree { - rows, err := conn.Query(fmt.Sprintf("SELECT column_name FROM information_schema.columns WHERE table_name = '%s';", table)) + unique, err := getUniqueColumns(conn, table) + if err != nil { + return err + } + + pk, err := conn.Query(fmt.Sprintf("SELECT kcu.column_name FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_name = '%s';", table)) + if err != nil { + return err + } + defer pk.Close() + for pk.Next() { + if err := pk.Scan(&pkey); err != nil { + return err + } + } + + fk, err := conn.Query(fmt.Sprintf("SELECT tc.table_schema, tc.table_name, kcu.column_name, ccu.table_schema AS foreign_table_schema, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = '%s';", table)) + if err != nil { + return err + } + defer fk.Close() + for fk.Next() { + var fkey model.ForeignKey + if err := fk.Scan(new(interface{}), new(interface{}), &fkey.Column, new(interface{}), &fkey.ForeignTable, &fkey.ForeignColumn); err != nil { + return err + } + fkeys = append(fkeys, fkey) + } + + rows, err := conn.Query(fmt.Sprintf("SELECT column_name, is_nullable, data_type, character_maximum_length FROM information_schema.columns WHERE table_name = '%s';", table)) if err != nil { return err } defer rows.Close() for rows.Next() { - var column string - if err := rows.Scan(&column); err != nil { + var column model.Column + if err := rows.Scan(&column.Name, &column.Nullable, &column.Type, &column.MaxLength); err != nil { return err } + if column.Name == pkey { + column.PrimaryKey = true + } + for _, fkey := range fkeys { + if column.Name == fkey.Column { + column.ForeignKey = fkey + } else { + column.ForeignKey = model.ForeignKey{} + } + } + + for _, u := range unique { + if column.Name == u { + column.Unique = true + } + } + tree[table] = append(tree[table], column) } } return nil } + +// Returns a list of the unique columns in a table +func getUniqueColumns(conn *sql.DB, table string) ([]string, error) { + var cols []string + rows, err := conn.Query(fmt.Sprintf("SELECT kcu.column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'UNIQUE' AND kcu.table_name = '%s';", table)) + if err != nil { + return []string{}, err + } + defer rows.Close() + + for rows.Next() { + var col string + if err := rows.Scan(&col); err != nil { + return []string{}, err + } + cols = append(cols, col) + } + + return cols, nil +} diff --git a/internal/model/tree.go b/internal/model/tree.go new file mode 100644 index 0000000..d956a20 --- /dev/null +++ b/internal/model/tree.go @@ -0,0 +1,20 @@ +package model + +import "database/sql" + +// Column data structure +type Column struct { + Name string + Type string + MaxLength sql.NullInt64 + Nullable string + PrimaryKey bool + ForeignKey ForeignKey + Unique bool +} + +type ForeignKey struct { + Column string + ForeignTable string + ForeignColumn string +} diff --git a/internal/templates/tree.go b/internal/templates/tree.go index 188a3d4..a71d4b3 100644 --- a/internal/templates/tree.go +++ b/internal/templates/tree.go @@ -3,6 +3,8 @@ package templates import ( "fmt" "sort" + + "github.com/Azpect3120/Web-Database-Viewer/internal/model" ) // Tree definition @@ -28,13 +30,14 @@ const FIELDS_LIST_OPEN string = `
Connect and query your databases effortlessly.