"test/system/dai_test.go" did not exist on "ccd23059219f88abde3c414abdf45a09781531ff"
Newer
Older
/*
 * Copyright (c) 2020  InterDigital Communications, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package usersdb
import (
	"database/sql"
	"errors"
	"strings"
	log "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger"
	_ "github.com/lib/pq"
)
// DB Config
//
// Connection defaults. connectDB substitutes DbDefault, DbHost and DbPort
// when the corresponding argument is empty.
const (
	DbHost     = "meep-postgis.default.svc.cluster.local"
	DbPort     = "5432"
	DbUser     = ""
	DbPassword = ""
	DbDefault  = "postgres"
)

// DbMaxRetryCount - Extra attempts NewConnector makes to reach the default
// DB after the first one fails (DbMaxRetryCount+1 attempts in total).
const DbMaxRetryCount int = 2
// DB Table Names
const (
	UsersTable = "users"
)

// User roles accepted for the role column.
const (
	RoleUser  = "user"
	RoleSuper = "super"
)

// User - One row of the users table, in the column order the queries in this
// file SELECT and Scan (id, username, password, role, sboxname).
type User struct {
	id       string
	username string
	password string // stored crypt() hash, not the clear-text password
	role     string
	sboxname string
}
// Connector - Implements a Postgis SQL DB connector
type Connector struct {
	// Connector name and the Postgis DB name derived from it
	name, dbName string
	// Open handle used by every query method
	db *sql.DB
	// Set once NewConnector has connected to the connector DB
	connected bool
	// Two-string callback; not referenced in the code visible here
	// (NOTE(review): confirm how callers use it)
	updateCb func(string, string)
}
// NewConnector - Creates and initializes a Postgis connector
//
// It first connects to the default DB (DbDefault), making up to
// DbMaxRetryCount+1 back-to-back attempts. Next it best-effort creates the
// connector DB, named after the connector (lower-cased, dashes replaced by
// underscores). Finally it closes the default-DB handle and reconnects to the
// connector DB. Returns (nil, err) if name is empty or either connection fails.
func NewConnector(name, user, pwd, host, port string) (pc *Connector, err error) {
	if name == "" {
		err = errors.New("Missing connector name")
		return nil, err
	}

	// Create new connector
	pc = new(Connector)
	pc.name = name

	// Connect to Postgis DB (retries are immediate, no delay between attempts)
	for retry := 0; retry <= DbMaxRetryCount; retry++ {
		pc.db, err = pc.connectDB("", user, pwd, host, port)
		if err == nil {
			break
		}
	}
	if err != nil {
		log.Error("Failed to connect to postgis with err: ", err.Error())
		return nil, err
	}

	// Create DB if it does not exist
	// Use format: '<name>' & replace dashes with underscores
	pc.dbName = strings.ToLower(strings.Replace(name, "-", "_", -1))

	// Ignore DB creation error in case it already exists.
	// Failure will occur at DB connection if DB was not successfully created.
	_ = pc.CreateDb(pc.dbName)

	// Close connection to postgis. This is the only place the default-DB
	// handle is released; pc.db is replaced by the connector-DB handle below.
	pc.db.Close()

	// Connect with DB
	pc.db, err = pc.connectDB(pc.dbName, user, pwd, host, port)
	if err != nil {
		log.Error("Failed to connect to DB with err: ", err.Error())
		return nil, err
	}

	log.Info("Postgis Connector successfully created")
	pc.connected = true
	return pc, nil
}
// connectDB - Opens a Postgis DB handle and verifies it with a ping
//
// Empty dbName, host and port fall back to DbDefault, DbHost and DbPort.
// Every value is single-quoted and escaped (quoteConnParam) before it is
// placed in the key=value connection string. Unquoted, an empty user or
// password would make the parser take the following "key=" token as the
// value, and a value containing a space or quote would be split apart.
// Returns (nil, err) if the handle cannot be opened or the ping fails; a
// handle that fails the ping is closed before returning.
func (pc *Connector) connectDB(dbName, user, pwd, host, port string) (db *sql.DB, err error) {
	// Set default values if none provided
	if dbName == "" {
		dbName = DbDefault
	}
	if host == "" {
		host = DbHost
	}
	if port == "" {
		port = DbPort
	}

	log.Debug("Connecting to Postgis DB [", dbName, "] at addr [", host, ":", port, "]")

	// Open postgis DB
	connStr := "user=" + quoteConnParam(user) +
		" password=" + quoteConnParam(pwd) +
		" dbname=" + quoteConnParam(dbName) +
		" host=" + quoteConnParam(host) +
		" port=" + quoteConnParam(port) +
		" sslmode=disable"
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Warn("Failed to connect to Postgis DB with error: ", err.Error())
		return nil, err
	}

	// Make sure connection is up
	err = db.Ping()
	if err != nil {
		log.Warn("Failed to ping Postgis DB with error: ", err.Error())
		db.Close()
		return nil, err
	}

	log.Info("Connected to Postgis DB [", dbName, "]")
	return db, nil
}

// connParamEscaper backslash-escapes the two characters that are special
// inside a single-quoted connection-string value.
var connParamEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quoteConnParam - Returns value as a single-quoted, escaped connection-string
// value ('' for an empty string).
func quoteConnParam(value string) string {
	return "'" + connParamEscaper.Replace(value) + "'"
}
// CreateDb -- Issues CREATE DATABASE for the provided name
//
// The name is concatenated into the statement unquoted, so it must be a
// trusted, valid identifier. Any failure (NewConnector expects one when the
// DB already exists) is logged and returned.
func (pc *Connector) CreateDb(name string) (err error) {
	stmt := "CREATE DATABASE " + name
	if _, err = pc.db.Exec(stmt); err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info("Created database: " + name)
	return nil
}
// CreateTables - Creates the users table if it does not already exist
//
// The pgcrypto extension is enabled first because user passwords are written
// as crypt(..., gen_salt('bf')) hashes. Returns the first error encountered.
func (pc *Connector) CreateTables() (err error) {
	_, err = pc.db.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto")
	if err != nil {
		log.Error(err.Error())
		return err
	}

	// users Table (username is UNIQUE: lookups by username return at most one row)
	_, err = pc.db.Exec(`CREATE TABLE IF NOT EXISTS ` + UsersTable + ` (
		id			SERIAL			PRIMARY KEY,
		username	varchar(36)		NOT NULL UNIQUE,
		password	varchar(100)	NOT NULL,
		role		varchar(36)		NOT NULL DEFAULT 'user',
		sboxname	varchar(11)		NOT NULL DEFAULT ''
	)`)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info("Created table: ", UsersTable)
	return nil
}
// DeleteTables - Drops every table managed by this connector (currently only
// the users table). Best-effort: DeleteTable logs any failure, but it is not
// propagated, so this always returns nil.
func (pc *Connector) DeleteTables() (err error) {
	for _, table := range []string{UsersTable} {
		_ = pc.DeleteTable(table)
	}
	return nil
}
// DeleteTable - Drops the named table if it exists
//
// tableName is concatenated into the statement, so it must be trusted.
// Errors are logged and returned.
func (pc *Connector) DeleteTable(tableName string) (err error) {
	stmt := "DROP TABLE IF EXISTS " + tableName
	if _, err = pc.db.Exec(stmt); err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info("Deleted table: " + tableName)
	return nil
}
// CreateUser - Create new user
//
// username and password are required. An empty role defaults to RoleUser;
// any other role must pass isValidRole. The password is hashed in SQL with
// crypt(password, gen_salt('bf')). Every value, including the password, is
// sent as a bound parameter so user input can never alter the statement.
func (pc *Connector) CreateUser(username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if username == "" {
		return errors.New("Missing username")
	}
	if password == "" {
		return errors.New("Missing password")
	}
	if role == "" {
		role = RoleUser
	} else if err = isValidRole(role); err != nil {
		return err
	}

	// Create entry
	query := `INSERT INTO ` + UsersTable + ` (username, password, role, sboxname)
		VALUES ($1, crypt($2, gen_salt('bf')), $3, $4)`
	_, err = pc.db.Exec(query, username, password, role, sboxname)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
// UpdateUser - Update existing user
func (pc *Connector) UpdateUser(username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if username == "" {
		return errors.New("Missing username")
	}
	if password != "" {
		query := `UPDATE ` + UsersTable + `
			SET password = crypt('` + password + `', gen_salt('bf'))
		_, err = pc.db.Exec(query, username)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	if role != "" {
		err = isValidRole(role)
		if err != nil {
			return err
		}
		query := `UPDATE ` + UsersTable + `
			SET role = $2
		_, err = pc.db.Exec(query, username, role)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	if sboxname != "" {
		query := `UPDATE ` + UsersTable + `
			SET sboxname = $2
		_, err = pc.db.Exec(query, username, sboxname)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}
	return nil
}
// GetUser - Get user information
func (pc *Connector) GetUser(username string) (user *User, err error) {
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return nil, err
	}
	// Get user entry
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		FROM `+UsersTable+`
	if err != nil {
		log.Error(err.Error())
		return nil, err
	}
	defer rows.Close()
	// Scan result
	for rows.Next() {
		user = new(User)
		err = rows.Scan(&user.id, &user.username, &user.password, &user.role, &user.sboxname)
		if err != nil {
			log.Error(err.Error())
			return nil, err
		}
	}
	err = rows.Err()
	if err != nil {
		log.Error(err)
	}
	// Return error if not found
	if user == nil {
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
		return nil, err
	}
	return user, nil
}
// GetUsers - Get all users
//
// Returns every user keyed by username. The map is never nil. On a query,
// scan or row-iteration error, the error is returned together with the users
// read so far.
func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	// Create map
	userMap = make(map[string]*User)

	// Get user entries
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		FROM ` + UsersTable)
	if err != nil {
		log.Error(err.Error())
		return userMap, err
	}
	defer rows.Close()

	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id, &user.username, &user.password, &user.role, &user.sboxname)
		if err != nil {
			log.Error(err.Error())
			return userMap, err
		}
		// Add to map
		userMap[user.username] = user
	}

	// Report iteration errors rather than presenting a truncated map as complete
	if err = rows.Err(); err != nil {
		log.Error(err)
		return userMap, err
	}
	return userMap, nil
}
// DeleteUser - Removes the row for the given username
//
// username is required and is sent as a bound parameter. RowsAffected is not
// checked, so deleting an unknown username is not an error.
func (pc *Connector) DeleteUser(username string) (err error) {
	if username == "" {
		return errors.New("Missing username")
	}
	if _, err = pc.db.Exec(`DELETE FROM `+UsersTable+` WHERE username = ($1)`, username); err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
// DeleteUsers - Removes every row from the users table (the table is kept)
func (pc *Connector) DeleteUsers() (err error) {
	stmt := `DELETE FROM ` + UsersTable
	if _, err = pc.db.Exec(stmt); err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}
//IsValidUser - does if user exists
func (pc *Connector) IsValidUser(username string) (valid bool, err error) {
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()
	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	// User does not exist & no error
	return false, nil
}
//AuthenticateUser - returns true or false if credentials are OK
func (pc *Connector) AuthenticateUser(username string, password string) (authenticated bool, err error) {
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}
		AND password = crypt('`+password+`', password)`, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()
	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	}
	// User does not exist & no error
	return false, nil
}
// isValidRole - does role exist
func isValidRole(role string) error {