Unverified Commit 01b832f5 authored by Kevin Di Lallo's avatar Kevin Di Lallo Committed by GitHub
Browse files

Merge pull request #109 from roymx/meep-users

meep-users package
parents 85bac1c7 975841b6
Loading
Loading
Loading
Loading
+457 −0
Original line number Diff line number Diff line
/*
 * Copyright (c) 2020  InterDigital Communications, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package usersdb

import (
	"database/sql"
	"errors"
	"strings"

	log "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger"

        _ "github.com/lib/pq"
)

// DB Config
const (
	DbHost              = "meep-postgis.default.svc.cluster.local"
	DbPort              = "5432"
	DbUser              = ""
	DbPassword          = ""
	DbDefault           = "postgres"
	DbMaxRetryCount int = 2
)

// DB Table Names
const (
	UsersTable      = "users"
)

const (
	RoleUser				= "user"
	RoleSuper				= "super"
)

type User struct {
	id   					string
	username   		string
	password			string
	role					string
	sboxname			string
}

// Connector - Implements a Postgis SQL DB connector
type Connector struct {
	name      string
	dbName    string
	db        *sql.DB
	connected bool
	updateCb  func(string, string)
}

// NewConnector - Creates and initializes a Postgis connector
func NewConnector(name, user, pwd, host, port string) (pc *Connector, err error) {
	if name == "" {
		err = errors.New("Missing connector name")
		return nil, err
	}

	// Create new connector
	pc = new(Connector)
	pc.name = name

	// Connect to Postgis DB
	for retry := 0; retry <= DbMaxRetryCount; retry++ {
		pc.db, err = pc.connectDB("", user, pwd, host, port)
		if err == nil {
			break
		}
	}
	if err != nil {
		log.Error("Failed to connect to postgis with err: ", err.Error())
		return nil, err
	}
	defer pc.db.Close()

	// Create DB if it does not exist
	// Use format: '<name>' & replace dashes with underscores
	pc.dbName = strings.ToLower(strings.Replace(name, "-", "_", -1))

	// Ignore DB creation error in case it already exists.
	// Failure will occur at DB connection if DB was not successfully created.
	_ = pc.CreateDb(pc.dbName)

	// Close connection to postgis
	pc.db.Close()

	// Connect with DB
	pc.db, err = pc.connectDB(pc.dbName, user, pwd, host, port)
	if err != nil {
		log.Error("Failed to connect to DB with err: ", err.Error())
		return nil, err
	}

	log.Info("Postgis Connector successfully created")
	pc.connected = true
	return pc, nil
}

func (pc *Connector) connectDB(dbName, user, pwd, host, port string) (db *sql.DB, err error) {
	// Set default values if none provided
	if dbName == "" {
		dbName = DbDefault
	}
	if host == "" {
		host = DbHost
	}
	if port == "" {
		port = DbPort
	}
	log.Debug("Connecting to Postgis DB [", dbName, "] at addr [", host, ":", port, "]")

	// Open postgis DB
	connStr := "user=" + user + " password=" + pwd + " dbname=" + dbName + " host=" + host + " port=" + port + " sslmode=disable"
	db, err = sql.Open("postgres", connStr)
	if err != nil {
		log.Warn("Failed to connect to Postgis DB with error: ", err.Error())
		return nil, err
	}

	// Make sure connection is up
	err = db.Ping()
	if err != nil {
		log.Warn("Failed to ping Postgis DB with error: ", err.Error())
		db.Close()
		return nil, err
	}

	log.Info("Connected to Postgis DB [", dbName, "]")
	return db, nil
}

// CreateDb -- Create new DB with provided name
func (pc *Connector) CreateDb(name string) (err error) {
	_, err = pc.db.Exec("CREATE DATABASE " + name)
	if err != nil {
		log.Error(err.Error())
		return err
	}

	log.Info("Created database: " + name)
	return nil
}

func (pc *Connector) CreateTables() (err error) {
	_, err = pc.db.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto")
	if err != nil {
		log.Error(err.Error())
		return err
	}

	// users Table
	_, err = pc.db.Exec(`CREATE TABLE ` + UsersTable + ` (
		id 							SERIAL 						PRIMARY KEY,
		username  			varchar(36)				NOT NULL UNIQUE,
		password				varchar(100)				NOT NULL,
		role						varchar(36)				NOT NULL DEFAULT 'user',
		sboxname				varchar(11)				NOT NULL DEFAULT ''
	)`)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info("Created table: ", UsersTable)

	return nil
}

// DeleteTables - Delete all tables
func (pc *Connector) DeleteTables() (err error) {
	_ = pc.DeleteTable(UsersTable)
	return nil
}

// DeleteTable - Delete table with provided name
func (pc *Connector) DeleteTable(tableName string) (err error) {
	_, err = pc.db.Exec("DROP TABLE IF EXISTS " + tableName)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	log.Info("Deleted table: " + tableName)
	return nil
}

// CreateUser - Create new user
func (pc *Connector) CreateUser(username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if username == "" {
		return errors.New("Missing username")
	}
	if password == "" {
		return errors.New("Missing password")
	}
	if role == "" {
		role = RoleUser
	} else {
		err = isValidRole(role)
		if err != nil {
			return err
		}
	}

	// Create entry
	query := `INSERT INTO ` + UsersTable + ` (username, password, role, sboxname)
		VALUES ($1, crypt('`+password+`', gen_salt('bf')), $2, $3)`
	_, err = pc.db.Exec(query, username, role, sboxname)
	if err != nil {
		log.Error(err.Error())
		return err
	}

	return nil
}

// UpdateUser - Update existing user
func (pc *Connector) UpdateUser(username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if username == "" {
		return errors.New("Missing username")
	}

	if password != "" {
		query := `UPDATE ` + UsersTable + `
			SET password = crypt('`+password+`', gen_salt('bf'))
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}

	if role != "" {
		err = isValidRole(role)
		if err != nil {
			return err
		}
		query := `UPDATE ` + UsersTable + `
			SET role = $2
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username, role)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}

	if sboxname != "" {
		query := `UPDATE ` + UsersTable + `
			SET sboxname = $2
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username, sboxname)
		if err != nil {
			log.Error(err.Error())
			return err
		}
	}

	return nil
}

// GetUser - Get user information
func (pc *Connector) GetUser(username string) (user *User, err error) {
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return nil, err
	}

	// Get user entry
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		FROM `+UsersTable+`
		WHERE username = ($1)`, username)
	if err != nil {
		log.Error(err.Error())
		return nil, err
	}
	defer rows.Close()

	// Scan result
	for rows.Next() {
		user = new(User)
		err = rows.Scan(&user.id, &user.username, &user.password, &user.role, &user.sboxname)
		if err != nil {
			log.Error(err.Error())
			return nil, err
		}
	}
	err = rows.Err()
	if err != nil {
		log.Error(err)
	}

	// Return error if not found
	if user == nil {
		err = errors.New("user not found: " + username)
		return nil, err
	}
	return user, nil
}

// GetAllUsers - Get All users
func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	// Create map
	userMap = make(map[string]*User)

	// Get user entries
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		FROM ` + UsersTable)
	if err != nil {
		log.Error(err.Error())
		return userMap, err
	}
	defer rows.Close()

	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id, &user.username, &user.password, &user.role, &user.sboxname)
		if err != nil {
			log.Error(err.Error())
			return userMap, err
		}

		// Add to map
		userMap[user.username] = user
	}
	err = rows.Err()
	if err != nil {
		log.Error(err)
	}

	return userMap, nil
}

// DeleteUser - Delete user entry
func (pc *Connector) DeleteUser(username string) (err error) {
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return err
	}

	_, err = pc.db.Exec(`DELETE FROM `+UsersTable+` WHERE username = ($1)`, username)
	if err != nil {
		log.Error(err.Error())
		return err
	}

	return nil
}

// DeleteAllUsers - Delete all users entries
func (pc *Connector) DeleteUsers() (err error) {
	_, err = pc.db.Exec(`DELETE FROM ` + UsersTable)
	if err != nil {
		log.Error(err.Error())
		return err
	}
	return nil
}

//IsValidUser - does if user exists
func (pc *Connector) IsValidUser(username string) (valid bool, err error){
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}

	rows, err := pc.db.Query(`
		SELECT id
		FROM `+UsersTable+`
		WHERE username = ($1)`, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()

	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	}
	// User does not exist & no error
	return false, nil
}

//AuthenticateUser - returns true or false if credentials are OK
func (pc *Connector) AuthenticateUser(username string, password string) (authenticated bool, err error){
	// Validate input
	if username == "" {
		err = errors.New("Missing username")
		return false, err
	}

	rows, err := pc.db.Query(`
		SELECT id
		FROM `+UsersTable+`
		WHERE username = ($1)
		AND password = crypt('`+password+`', password)` , username)
	if err != nil {
		log.Error(err.Error())
		return false, err
	}
	defer rows.Close()

	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.id)
		if err != nil {
			log.Error(err.Error())
			return false, err
		} else {
			//User exists
			return true, nil
		}
	}
	// User does not exist & no error
	return false, nil
}

// isValidRole - does role exist
func isValidRole(role string) (error) {
	switch role {
	case RoleUser, RoleSuper:
		return nil
	}
	return errors.New("Inalid role")
}
+286 −0
Original line number Diff line number Diff line
/*
 * Copyright (c) 2020  InterDigital Communications, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package usersdb

import (
	"fmt"
	"testing"

	log "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger"
)

const (
	pcName      = "pc"
	pcDBUser    = "postgres"
	pcDBPwd     = "pwd"
	pcDBHost    = "localhost"
	pcDBPort    = "30432"

	username0		= ""
	username1		= "user1"
	username2		= "user2"
	username3		= "user3"

	password0		= ""
	password1		= "123" //3 chars
	password2		= "gie[rh[iuhberieg" //16 chars
	password3		= "efbiwerbfiwferbirgfbiuqrfgbdrfgjnbqairbqifhrbeqi[frb[rifhb[qirfbq]]]qaef[048FERGerwWRGG]FASF03404924" // 100 chars

	role0				= "invalid-role"
	role1				= "user"
	role2				= "user"
	role3				= "super"

	sboxname0		= "123456789012345" // more than 11 chars
	sboxname1		= "sbox-1"
	sboxname2		= "sbox-2"
	sboxname3		= "sbox-3"

)

func TestConnector(t *testing.T) {
	fmt.Println("--- ", t.Name())
	log.MeepTextLogInit(t.Name())

	// Invalid Connector
	fmt.Println("Invalid Connector")
	pc, err := NewConnector("", pcDBUser, pcDBPwd, pcDBHost, pcDBPort)
	if err == nil || pc != nil {
		t.Fatalf("DB connection should have failed")
	}
	pc, err = NewConnector(pcName, pcDBUser, pcDBPwd, "invalid-host", pcDBPort)
	if err == nil || pc != nil {
		t.Fatalf("DB connection should have failed")
	}
	pc, err = NewConnector(pcName, pcDBUser, pcDBPwd, pcDBHost, "invalid-port")
	if err == nil || pc != nil {
		t.Fatalf("DB connection should have failed")
	}
	pc, err = NewConnector(pcName, pcDBUser, "invalid-pwd", pcDBHost, pcDBPort)
	if err == nil || pc != nil {
		t.Fatalf("DB connection should have failed")
	}

	// Valid Connector
	fmt.Println("Create valid Postgis Connector")
	pc, err = NewConnector(pcName, pcDBUser, pcDBPwd, pcDBHost, pcDBPort)
	if err != nil || pc == nil {
		t.Fatalf("Failed to create postgis Connector")
	}

	// Cleanup
	_ = pc.DeleteTable(UsersTable)

	// Create tables
	fmt.Println("Create Tables")
	err = pc.CreateTables()
	if err != nil {
		t.Fatalf("Failed to create tables")
	}

	// Cleanup
	err = pc.DeleteTables()
	if err != nil {
		t.Fatalf("Failed to create tables")
	}

	// t.Fatalf("DONE")
}

func TestPostgisCreateUser(t *testing.T) {
	fmt.Println("--- ", t.Name())
	log.MeepTextLogInit(t.Name())

	// Create Connector
	fmt.Println("Create valid Connector")
	pc, err := NewConnector(pcName, pcDBUser, pcDBPwd, pcDBHost, pcDBPort)
	if err != nil || pc == nil {
		t.Fatalf("Failed to create postgis Connector")
	}

	// Cleanup
	_ = pc.DeleteTables()

	// Create tables
	fmt.Println("Create Tables")
	err = pc.CreateTables()
	if err != nil {
		t.Fatalf("Failed to create tables")
	}

	// Make sure users don't exist
	fmt.Println("Verify no user present")
	userMap, err := pc.GetUsers()
	if err != nil {
		t.Fatalf("Failed to get all users")
	}
	if len(userMap) != 0 {
		t.Fatalf("No user should be present")
	}

	fmt.Println("Create Invalid users")
	err = pc.CreateUser(username0, password1, role1, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password0, role1, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password1, role0, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password1, role1, sboxname0)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}

	fmt.Println("user DB operations")
	err = pc.CreateUser(username1, password1, role1, sboxname1)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err := pc.GetUser(username1)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.username != username1 || user.role != role1 || user.sboxname != sboxname1 {
		t.Fatalf("Wrong user data")
	}
	if user.password == password1 {
		t.Fatalf("Password not encrypted")
	}
	valid,err := pc.IsValidUser(username1)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
	valid,err = pc.AuthenticateUser(username1, password1)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
	valid,err = pc.AuthenticateUser(username1, password2)
	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	err = pc.CreateUser(username2, password2, role2, sboxname2)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err = pc.GetUser(username2)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.username != username2 || user.role != role2 || user.sboxname != sboxname2 {
		t.Fatalf("Wrong user data")
	}
	if user.password == password2 {
		t.Fatalf("Password not encrypted")
	}
	valid,err = pc.IsValidUser(username2)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
        valid,err = pc.AuthenticateUser(username2, password2)
        if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
        valid,err = pc.AuthenticateUser(username2, password1)
        if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	err = pc.CreateUser(username3, password3, role3, sboxname3)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err = pc.GetUser(username3)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.username != username3 || user.role != role3 || user.sboxname != sboxname3 {
		t.Fatalf("Wrong user data")
	}
	if user.password == password3 {
		t.Fatalf("Password not encrypted")
	}
	valid,err = pc.IsValidUser(username3)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
        valid,err = pc.AuthenticateUser(username3, password3)
        if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
        valid,err = pc.AuthenticateUser(username3, password2)
        if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	// Verify all additions worked
	userMap, err = pc.GetUsers()
	if err != nil || len(userMap) != 3 {
		t.Fatalf("Error getting all users")
	}

	// Remove & validate update
	fmt.Println("Remove user & validate update")
	err = pc.DeleteUser(username3)
	if err != nil {
		t.Fatalf("Failed to delete user")
	}
	user, err = pc.GetUser(username3)
	if err == nil || user != nil {
		t.Fatalf("user should no longer exist")
	}

	// Update & validate update
	fmt.Println("Add user & validate update")
	err = pc.UpdateUser(username1,password3,role3,sboxname3)
	if err != nil {
		t.Fatalf("Failed to update asset")
	}
	user, err = pc.GetUser(username1)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.username != username1 || user.role != role3 || user.sboxname != sboxname3 {
		t.Fatalf("Wrong user data")
	}
	valid,err = pc.AuthenticateUser(username1,password3)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
 	valid,err = pc.AuthenticateUser(username1,password1)
    	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	// Delete all users & validate updates
	fmt.Println("Delete all users & validate updates")
	err = pc.DeleteUsers()
	if err != nil {
		t.Fatalf("Failed to delete all user")
	}
	userMap, err = pc.GetUsers()
	if err != nil || len(userMap) != 0 {
		t.Fatalf("user should no longer exist")
	}

	// t.Fatalf("DONE")
}
+10 −0
Original line number Diff line number Diff line
module github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-users

go 1.12

require (
	github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger v0.0.0
	github.com/lib/pq v1.5.2
)

replace github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-logger => ../../go-packages/meep-logger
+11 −0
Original line number Diff line number Diff line
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw=
github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=