/*
 * Copyright (c) 2020  InterDigital Communications, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package usersdb
import (
	"database/sql"
	"errors"
	"strings"
	log "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger"
	_ "github.com/lib/pq"
)
// DB Config
// Default connection settings for the Postgis server.
const (
	DbHost              = "meep-postgis.default.svc.cluster.local" // host used by connectDB when the caller passes an empty host
	DbPort              = "5432"                                   // port used by connectDB when the caller passes an empty port
	DbUser              = ""                                       // NOTE(review): not referenced by the code in this file - confirm usage by callers
	DbPassword          = ""                                       // NOTE(review): not referenced by the code in this file - confirm usage by callers
	DbDefault           = "postgres"                               // DB name used by connectDB when the caller passes an empty DB name
	DbMaxRetryCount int = 2                                        // number of retries made by NewConnector after the first failed connection attempt
)
// DB Table Names
const (
	UsersTable = "users"
)

// User providers
const (
	// ProviderLocal is the provider assumed whenever a caller passes an empty provider.
	ProviderLocal = "local"
)

// User roles
const (
	RoleUser  = "user"
	RoleAdmin = "admin"
)

// User - Entry of the users table
// Each field is scanned from the column of the same name by GetUser & GetUsers.
type User struct {
	Id       string
	Provider string
	Username string
	Password string // value stored in the password column (written through crypt())
	Role     string
	Sboxname string
}
// Connector - Implements a Postgis SQL DB connector
// Instances are created by NewConnector; every method runs its statements on db.
type Connector struct {
	name      string  // connector name, as passed to NewConnector
	dbName    string  // DB name derived from name: lower-cased, dashes replaced by underscores
	db        *sql.DB // handle of the currently open DB connection
	connected bool    // set to true by NewConnector once the connector DB is reachable
}
// NewConnector - Creates and initializes a Postgis connector
//
// The connector first connects to the default DB, creates the connector DB
// (name lower-cased, dashes replaced by underscores) and then reconnects to
// that DB. Empty host & port values are replaced by defaults in connectDB.
//
// Returns a nil connector and an error if name is empty or if either
// connection cannot be established.
func NewConnector(name, user, pwd, host, port string) (pc *Connector, err error) {
	if name == "" {
		err = errors.New("Missing connector name")
		return nil, err
	}

	// Create new connector
	pc = new(Connector)
	pc.name = name

	// Connect to Postgis DB; one initial attempt plus DbMaxRetryCount retries
	for retry := 0; retry <= DbMaxRetryCount; retry++ {
		pc.db, err = pc.connectDB("", user, pwd, host, port)
		if err == nil {
			break
		}
	}
	if err != nil {
		log.Error("Failed to connect to postgis with err: ", err.Error())
		return nil, err
	}

	// Create DB if it does not exist
	// Use format: '<name>' & replace dashes with underscores
	pc.dbName = strings.ToLower(strings.Replace(name, "-", "_", -1))

	// Ignore DB creation error in case it already exists.
	// Failure will occur at DB connection if DB was not successfully created.
	_ = pc.CreateDb(pc.dbName)

	// Close connection to postgis; no return path lies between the successful
	// connection above and this call, so a single explicit close is sufficient.
	pc.db.Close()

	// Connect with DB
	pc.db, err = pc.connectDB(pc.dbName, user, pwd, host, port)
	if err != nil {
		log.Error("Failed to connect to DB with err: ", err.Error())
		return nil, err
	}

	log.Info("Postgis Connector successfully created")
	pc.connected = true
	return pc, nil
}
// connectDB - Opens a connection to the provided DB and verifies it with a ping
//
// Empty dbName, host & port values are replaced by DbDefault, DbHost & DbPort.
// Every connection string value is single-quoted so that empty values and
// values containing whitespace, quotes or backslashes cannot alter the other
// connection parameters.
//
// Returns a nil handle and the error if the DB cannot be opened or pinged.
func (pc *Connector) connectDB(dbName, user, pwd, host, port string) (db *sql.DB, err error) {
	// Set default values if none provided
	if dbName == "" {
		dbName = DbDefault
	}
	if host == "" {
		host = DbHost
	}
	if port == "" {
		port = DbPort
	}
	log.Debug("Connecting to Postgis DB [", dbName, "] at addr [", host, ":", port, "]")

	// Open postgis DB
	connStr := "user=" + quoteConnValue(user) +
		" password=" + quoteConnValue(pwd) +
		" dbname=" + quoteConnValue(dbName) +
		" host=" + quoteConnValue(host) +
		" port=" + quoteConnValue(port) +
		" sslmode=disable"
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Warn("Failed to connect to Postgis DB with error: ", err.Error())
		return nil, err
	}

	// Make sure connection is up
	err = db.Ping()
	if err != nil {
		log.Warn("Failed to ping Postgis DB with error: ", err.Error())
		// Release the handle, it is not returned to the caller
		db.Close()
		return nil, err
	}

	log.Info("Connected to Postgis DB [", dbName, "]")
	return db, nil
}

// quoteConnValue - Single-quotes a key=value connection string value,
// backslash-escaping embedded backslashes and single quotes
func quoteConnValue(value string) string {
	escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(value)
	return "'" + escaped + "'"
}
// CreateDb -- Create new DB with provided name
// The name is appended verbatim to the CREATE DATABASE statement.
// Any execution error is logged and returned.
func (pc *Connector) CreateDb(name string) (err error) {
	if _, err = pc.db.Exec("CREATE DATABASE " + name); err != nil {
		log.Error(err.Error())
		return err
	}

	log.Info("Created database: " + name)
	return nil
}
// CreateTables - Create all tables
// Enables the pgcrypto extension, then creates the users table if it does not
// exist yet. The first error encountered is logged and returned.
func (pc *Connector) CreateTables() (err error) {
	// pgcrypto provides the crypt() & gen_salt() functions used on passwords
	_, err = pc.db.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto")
	if err != nil {
		log.Error(err.Error())
		return err
	}

	// users Table
	_, err = pc.db.Exec(`CREATE TABLE IF NOT EXISTS ` + UsersTable + ` (
		id			SERIAL			PRIMARY KEY,
		provider	varchar(20)		NOT NULL DEFAULT '` + ProviderLocal + `',
		username	varchar(36)		NOT NULL,
		password	varchar(100)	NOT NULL,
		role		varchar(36)		NOT NULL DEFAULT '` + RoleUser + `',
		sboxname	varchar(11)		NOT NULL DEFAULT ''
	)`)
	if err != nil {
		log.Error(err.Error())
		return err
	}

	log.Info("Created table: ", UsersTable)
	return nil
}
// DeleteTables - Delete all tables
// Deletion is best effort: errors reported by DeleteTable are discarded and
// nil is always returned.
func (pc *Connector) DeleteTables() (err error) {
	for _, table := range []string{UsersTable} {
		_ = pc.DeleteTable(table)
	}
	return nil
}
// DeleteTable - Delete table with provided name
// The name is appended verbatim to the DROP TABLE IF EXISTS statement.
// Any execution error is logged and returned.
func (pc *Connector) DeleteTable(tableName string) (err error) {
	if _, err = pc.db.Exec("DROP TABLE IF EXISTS " + tableName); err != nil {
		log.Error(err.Error())
		return err
	}

	log.Info("Deleted table: " + tableName)
	return nil
}
// CreateUser - Create new user
Kevin Di Lallo
committed
func (pc *Connector) CreateUser(provider string, username string, password string, role string, sboxname string) (err error) {
Kevin Di Lallo
committed
	if username == "" {
		return errors.New("Missing username")
	}
	if role == "" {
		role = RoleUser
	} else {
		err = isValidRole(role)
		if err != nil {
			return err
		}
	}
	// Create entry
Kevin Di Lallo
committed
	query := `INSERT INTO ` + UsersTable + ` (provider, username, password, role, sboxname)
		VALUES ($1, $2, crypt('` + password + `', gen_salt('bf')), $3, $4)`
	_, err = pc.db.Exec(query, provider, username, role, sboxname)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
// UpdateUser - Update existing user
Kevin Di Lallo
committed
func (pc *Connector) UpdateUser(provider string, username string, password string, role string, sboxname string) (err error) {
Kevin Di Lallo
committed
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		return errors.New("Missing username")
	}
	if password != "" {
		query := `UPDATE ` + UsersTable + `
			SET password = crypt('` + password + `', gen_salt('bf'))
Kevin Di Lallo
committed
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	if role != "" {
		err = isValidRole(role)
		if err != nil {
			return err
		}
		query := `UPDATE ` + UsersTable + `
Kevin Di Lallo
committed
			SET role = $3
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username, role)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	if sboxname != "" {
		query := `UPDATE ` + UsersTable + `
Kevin Di Lallo
committed
			SET sboxname = $3
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username, sboxname)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	return nil
}
// GetUser - Get user information
Kevin Di Lallo
committed
func (pc *Connector) GetUser(provider string, username string) (user *User, err error) {
Kevin Di Lallo
committed
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return nil, err
	}
	// Get user entry
	var rows *sql.Rows
	rows, err = pc.db.Query(`
Kevin Di Lallo
committed
		SELECT id, provider, username, password, role, sboxname
Kevin Di Lallo
committed
		WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return nil, err
	}
	defer rows.Close()
	// Scan result
	for rows.Next() {
		user = new(User)
Kevin Di Lallo
committed
		err = rows.Scan(&user.Id, &user.Provider, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		if err != nil {
			log.Error(err.Error())
			return nil, err
		}
	}
	err = rows.Err()
	if err != nil {
		log.Error(err)
	}
	// Return error if not found
	if user == nil {
Kevin Di Lallo
committed
		err = errors.New(provider + " user not found: " + username)
		return nil, err
	}
	return user, nil
}
// GetAllUsers - Get All users
func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	// Create map
	userMap = make(map[string]*User)
	// Get user entries
	var rows *sql.Rows
	rows, err = pc.db.Query(`
Kevin Di Lallo
committed
		SELECT id, provider, username, password, role, sboxname
		FROM ` + UsersTable)
	if err != nil {
		log.Error(err.Error())
		return userMap, err
	}
	defer rows.Close()
	// Scan results
	for rows.Next() {
		user := new(User)
Kevin Di Lallo
committed
		err = rows.Scan(&user.Id, &user.Provider, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		if err != nil {
			log.Error(err.Error())
			return userMap, err
		}
		// Add to map
Kevin Di Lallo
committed
		userMap[pc.GetUserKey(user.Provider, user.Username)] = user
	}
	err = rows.Err()
	if err != nil {
		log.Error(err)
	}
	return userMap, nil
}
Kevin Di Lallo
committed
// GetUserKey - Get provider-specific user key
// The key has the format '<provider>-<username>'; an empty provider is
// replaced by ProviderLocal.
func (pc *Connector) GetUserKey(provider string, username string) (key string) {
	key = provider
	if key == "" {
		key = ProviderLocal
	}
	key += "-" + username
	return key
}
Kevin Di Lallo
committed
func (pc *Connector) DeleteUser(provider string, username string) (err error) {
Kevin Di Lallo
committed
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return err
	}
Kevin Di Lallo
committed
	_, err = pc.db.Exec(`DELETE FROM `+UsersTable+` WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
// DeleteUsers - Delete all user entries
// Any execution error is logged and returned.
func (pc *Connector) DeleteUsers() (err error) {
	if _, err = pc.db.Exec(`DELETE FROM ` + UsersTable); err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
//IsValidUser - does if user exists
Kevin Di Lallo
committed
func (pc *Connector) IsValidUser(provider string, username string) (valid bool, err error) {
Kevin Di Lallo
committed
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}
Kevin Di Lallo
committed
		WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()
	// Scan results
	for rows.Next() {
		user := new(User)
Kevin Di Lallo
committed
		err = rows.Scan(&user.Id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	// User does not exist & no error
	return false, nil
}
//AuthenticateUser - returns true or false if credentials are OK
Kevin Di Lallo
committed
func (pc *Connector) AuthenticateUser(provider string, username string, password string) (authenticated bool, err error) {
Kevin Di Lallo
committed
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}
Kevin Di Lallo
committed
		WHERE provider = ($1) AND username = ($2)
		AND password = crypt('`+password+`', password)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()
	// Scan results
	for rows.Next() {
		user := new(User)
Kevin Di Lallo
committed
		err = rows.Scan(&user.Id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	}
	// User does not exist & no error
	return false, nil
}
// isValidRole - does role exist
func isValidRole(role string) error {
Kevin Di Lallo
committed
	case RoleUser, RoleAdmin: