Commit 5703f2a5 authored by Kevin Di Lallo's avatar Kevin Di Lallo
Browse files

added oauth provider to user & session databases for provider-specific sessions

parent 92fff640
Loading
Loading
Loading
Loading
+13 −10
Original line number Diff line number Diff line
@@ -51,6 +51,7 @@ import (

const OAUTH_PROVIDER_GITHUB = "github"
const OAUTH_PROVIDER_GITLAB = "gitlab"
const OAUTH_PROVIDER_LOCAL = "local"

var mutex sync.Mutex

@@ -207,12 +208,13 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
		return
	}

	// Get provider-specific OAuth config
	provider := request.provider
	config := pfmCtrl.oauthConfigs[provider]

	// Delete login request & timer
	delLoginRequest(state)

	// Get provider-specific OAuth config
	config := pfmCtrl.oauthConfigs[request.provider]

	// Retrieve access token
	token, err := config.Exchange(context.Background(), code)
	if err != nil {
@@ -223,7 +225,7 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {

	// Retrieve User ID
	var userId string
	switch request.provider {
	switch provider {
	case OAUTH_PROVIDER_GITHUB:
		oauthClient := config.Client(context.Background(), token)
		client := github.NewClient(oauthClient)
@@ -259,7 +261,7 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	}

	// Start user session
	sandboxName, err, errCode := startSession(userId, w, r)
	sandboxName, err, errCode := startSession(provider, userId, w, r)
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, getErrUrl(err.Error()), errCode)
@@ -278,14 +280,14 @@ func uaLoginUser(w http.ResponseWriter, r *http.Request) {
	password := r.FormValue("password")

	// Validate user credentials
	authenticated, err := pfmCtrl.userStore.AuthenticateUser(username, password)
	authenticated, err := pfmCtrl.userStore.AuthenticateUser(OAUTH_PROVIDER_LOCAL, username, password)
	if err != nil || !authenticated {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	// Start user session
	sandboxName, err, errCode := startSession(username, w, r)
	sandboxName, err, errCode := startSession(OAUTH_PROVIDER_LOCAL, username, w, r)
	if err != nil {
		log.Error(err.Error())
		http.Error(w, err.Error(), errCode)
@@ -311,11 +313,11 @@ func uaLoginUser(w http.ResponseWriter, r *http.Request) {
}

// Retrieve existing user session or create a new one
func startSession(username string, w http.ResponseWriter, r *http.Request) (sandboxName string, err error, code int) {
func startSession(provider string, username string, w http.ResponseWriter, r *http.Request) (sandboxName string, err error, code int) {

	// Get existing session by user name, if any
	sessionStore := pfmCtrl.sessionMgr.GetSessionStore()
	session, err := sessionStore.GetByName(username)
	session, err := sessionStore.GetByName(provider, username)
	if err != nil {
		// Check if max session count is reached before creating a new one
		count := sessionStore.GetCount()
@@ -326,7 +328,7 @@ func startSession(username string, w http.ResponseWriter, r *http.Request) (sand

		// Get requested sandbox name & role from user profile, if any
		role := users.RoleUser
		user, err := pfmCtrl.userStore.GetUser(username)
		user, err := pfmCtrl.userStore.GetUser(provider, username)
		if err == nil {
			sandboxName = user.Sboxname
			role = user.Role
@@ -352,6 +354,7 @@ func startSession(username string, w http.ResponseWriter, r *http.Request) (sand
		session = new(sm.Session)
		session.ID = ""
		session.Username = username
		session.Provider = provider
		session.Sandbox = sandboxName
		session.Role = role
	} else {
+8 −2
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ const SessionDuration = 1200 // 20 minutes
const (
	ValSessionID = "sid"
	ValUsername  = "user"
	ValProvider  = "provider"
	ValSandbox   = "sbox"
	ValRole      = "role"
	ValTimestamp = "timestamp"
@@ -54,6 +55,7 @@ const (
type Session struct {
	ID        string
	Username  string
	Provider  string
	Sandbox   string
	Role      string
	Timestamp time.Time
@@ -131,6 +133,7 @@ func (ss *SessionStore) Get(r *http.Request) (s *Session, err error) {
	s = new(Session)
	s.ID = sessionId
	s.Username = session[ValUsername]
	s.Provider = session[ValProvider]
	s.Sandbox = session[ValSandbox]
	s.Role = session[ValRole]
	s.Timestamp, _ = time.Parse(time.RFC3339, session[ValTimestamp])
@@ -166,6 +169,7 @@ func getSessionEntryHandler(key string, fields map[string]string, userData inter
	s := new(Session)
	s.ID = fields[ValSessionID]
	s.Username = fields[ValUsername]
	s.Provider = fields[ValProvider]
	s.Sandbox = fields[ValSandbox]
	s.Role = fields[ValRole]
	s.Timestamp, _ = time.Parse(time.RFC3339, fields[ValTimestamp])
@@ -174,10 +178,11 @@ func getSessionEntryHandler(key string, fields map[string]string, userData inter
}

// GetByName - Retrieve session by name
func (ss *SessionStore) GetByName(username string) (s *Session, err error) {
func (ss *SessionStore) GetByName(provider string, username string) (s *Session, err error) {
	// Get existing session, if any
	s = new(Session)
	s.Username = username
	s.Provider = provider
	err = ss.rc.ForEachEntry(ss.baseKey+"*", getUserEntryHandler, s)
	if err != nil {
		return nil, err
@@ -199,7 +204,7 @@ func getUserEntryHandler(key string, fields map[string]string, userData interfac
	}

	// look for matching username
	if fields[ValUsername] == s.Username {
	if fields[ValUsername] == s.Username && fields[ValProvider] == s.Provider {
		s.ID = fields[ValSessionID]
		s.Sandbox = fields[ValSandbox]
		s.Role = fields[ValRole]
@@ -229,6 +234,7 @@ func (ss *SessionStore) Set(s *Session, w http.ResponseWriter, r *http.Request)
	fields := make(map[string]interface{})
	fields[ValSessionID] = sessionId
	fields[ValUsername] = s.Username
	fields[ValProvider] = s.Provider
	fields[ValSandbox] = s.Sandbox
	fields[ValRole] = s.Role
	fields[ValTimestamp] = time.Now().Format(time.RFC3339)
+67 −32
Original line number Diff line number Diff line
@@ -41,13 +41,18 @@ const (
	UsersTable = "users"
)

const (
	ProviderLocal = "local"
)

const (
	RoleUser  = "user"
	RoleSuper = "super"
	RoleAdmin = "admin"
)

type User struct {
	Id       string
	Provider string
	Username string
	Password string
	Role     string
@@ -164,9 +169,10 @@ func (pc *Connector) CreateTables() (err error) {
	// users Table
	_, err = pc.db.Exec(`CREATE TABLE IF NOT EXISTS ` + UsersTable + ` (
		id			SERIAL			PRIMARY KEY,
		username	varchar(36)		NOT NULL UNIQUE,
		provider	varchar(20)		NOT NULL DEFAULT '` + ProviderLocal + `',
		username	varchar(36)		NOT NULL,
		password	varchar(100)	NOT NULL,
		role		varchar(36)		NOT NULL DEFAULT 'user',
		role		varchar(36)		NOT NULL DEFAULT '` + RoleUser + `',
		sboxname	varchar(11)		NOT NULL DEFAULT ''
	)`)
	if err != nil {
@@ -196,11 +202,14 @@ func (pc *Connector) DeleteTable(tableName string) (err error) {
}

// CreateUser - Create new user
func (pc *Connector) CreateUser(username string, password string, role string, sboxname string) (err error) {
func (pc *Connector) CreateUser(provider string, username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if username == "" {
		return errors.New("Missing username")
	}
	if username == "" {
		return errors.New("Missing username")
	}
	if password == "" {
		return errors.New("Missing password")
	}
@@ -212,11 +221,14 @@ func (pc *Connector) CreateUser(username string, password string, role string, s
			return err
		}
	}
	if provider == "" {
		provider = ProviderLocal
	}

	// Create entry
	query := `INSERT INTO ` + UsersTable + ` (username, password, role, sboxname)
		VALUES ($1, crypt('` + password + `', gen_salt('bf')), $2, $3)`
	_, err = pc.db.Exec(query, username, role, sboxname)
	query := `INSERT INTO ` + UsersTable + ` (provider, username, password, role, sboxname)
		VALUES ($1, $2, crypt('` + password + `', gen_salt('bf')), $3, $4)`
	_, err = pc.db.Exec(query, provider, username, role, sboxname)
	if err != nil {
		log.Error(err.Error())
		return err
@@ -226,8 +238,11 @@ func (pc *Connector) CreateUser(username string, password string, role string, s
}

// UpdateUser - Update existing user
func (pc *Connector) UpdateUser(username string, password string, role string, sboxname string) (err error) {
func (pc *Connector) UpdateUser(provider string, username string, password string, role string, sboxname string) (err error) {
	// Validate input
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		return errors.New("Missing username")
	}
@@ -235,8 +250,8 @@ func (pc *Connector) UpdateUser(username string, password string, role string, s
	if password != "" {
		query := `UPDATE ` + UsersTable + `
			SET password = crypt('` + password + `', gen_salt('bf'))
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username)
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username)
		if err != nil {
			log.Error(err.Error())
			return err
@@ -249,9 +264,9 @@ func (pc *Connector) UpdateUser(username string, password string, role string, s
			return err
		}
		query := `UPDATE ` + UsersTable + `
			SET role = $2
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username, role)
			SET role = $3
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username, role)
		if err != nil {
			log.Error(err.Error())
			return err
@@ -260,9 +275,9 @@ func (pc *Connector) UpdateUser(username string, password string, role string, s

	if sboxname != "" {
		query := `UPDATE ` + UsersTable + `
			SET sboxname = $2
			WHERE username = ($1)`
		_, err = pc.db.Exec(query, username, sboxname)
			SET sboxname = $3
			WHERE provider = ($1) AND username = ($2)`
		_, err = pc.db.Exec(query, provider, username, sboxname)
		if err != nil {
			log.Error(err.Error())
			return err
@@ -273,8 +288,11 @@ func (pc *Connector) UpdateUser(username string, password string, role string, s
}

// GetUser - Get user information
func (pc *Connector) GetUser(username string) (user *User, err error) {
func (pc *Connector) GetUser(provider string, username string) (user *User, err error) {
	// Validate input
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return nil, err
@@ -283,9 +301,9 @@ func (pc *Connector) GetUser(username string) (user *User, err error) {
	// Get user entry
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		SELECT id, provider, username, password, role, sboxname
		FROM `+UsersTable+`
		WHERE username = ($1)`, username)
		WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return nil, err
@@ -295,7 +313,7 @@ func (pc *Connector) GetUser(username string) (user *User, err error) {
	// Scan result
	for rows.Next() {
		user = new(User)
		err = rows.Scan(&user.Id, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		err = rows.Scan(&user.Id, &user.Provider, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		if err != nil {
			log.Error(err.Error())
			return nil, err
@@ -308,7 +326,7 @@ func (pc *Connector) GetUser(username string) (user *User, err error) {

	// Return error if not found
	if user == nil {
		err = errors.New("user not found: " + username)
		err = errors.New(provider + " user not found: " + username)
		return nil, err
	}
	return user, nil
@@ -322,7 +340,7 @@ func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	// Get user entries
	var rows *sql.Rows
	rows, err = pc.db.Query(`
		SELECT id, username, password, role, sboxname
		SELECT id, provider, username, password, role, sboxname
		FROM ` + UsersTable)
	if err != nil {
		log.Error(err.Error())
@@ -333,14 +351,14 @@ func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	// Scan results
	for rows.Next() {
		user := new(User)
		err = rows.Scan(&user.Id, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		err = rows.Scan(&user.Id, &user.Provider, &user.Username, &user.Password, &user.Role, &user.Sboxname)
		if err != nil {
			log.Error(err.Error())
			return userMap, err
		}

		// Add to map
		userMap[user.Username] = user
		userMap[pc.GetUserKey(user.Provider, user.Username)] = user
	}
	err = rows.Err()
	if err != nil {
@@ -350,15 +368,26 @@ func (pc *Connector) GetUsers() (userMap map[string]*User, err error) {
	return userMap, nil
}

// GetUserKey - Get provider-specific user key
func (pc *Connector) GetUserKey(provider string, username string) (key string) {
	if provider == "" {
		provider = ProviderLocal
	}
	return provider + "-" + username
}

// DeleteUser - Delete user entry
func (pc *Connector) DeleteUser(username string) (err error) {
func (pc *Connector) DeleteUser(provider string, username string) (err error) {
	// Validate input
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return err
	}

	_, err = pc.db.Exec(`DELETE FROM `+UsersTable+` WHERE username = ($1)`, username)
	_, err = pc.db.Exec(`DELETE FROM `+UsersTable+` WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return err
@@ -378,8 +407,11 @@ func (pc *Connector) DeleteUsers() (err error) {
}

//IsValidUser - does if user exists
func (pc *Connector) IsValidUser(username string) (valid bool, err error) {
func (pc *Connector) IsValidUser(provider string, username string) (valid bool, err error) {
	// Validate input
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return false, err
@@ -388,7 +420,7 @@ func (pc *Connector) IsValidUser(username string) (valid bool, err error) {
	rows, err := pc.db.Query(`
		SELECT id
		FROM `+UsersTable+`
		WHERE username = ($1)`, username)
		WHERE provider = ($1) AND username = ($2)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
@@ -412,8 +444,11 @@ func (pc *Connector) IsValidUser(username string) (valid bool, err error) {
}

//AuthenticateUser - returns true or false if credentials are OK
func (pc *Connector) AuthenticateUser(username string, password string) (authenticated bool, err error) {
func (pc *Connector) AuthenticateUser(provider string, username string, password string) (authenticated bool, err error) {
	// Validate input
	if provider == "" {
		provider = ProviderLocal
	}
	if username == "" {
		err = errors.New("Missing username")
		return false, err
@@ -422,8 +457,8 @@ func (pc *Connector) AuthenticateUser(username string, password string) (authent
	rows, err := pc.db.Query(`
		SELECT id
		FROM `+UsersTable+`
		WHERE username = ($1)
		AND password = crypt('`+password+`', password)`, username)
		WHERE provider = ($1) AND username = ($2)
		AND password = crypt('`+password+`', password)`, provider, username)
	if err != nil {
		log.Error(err.Error())
		return false, err
@@ -449,7 +484,7 @@ func (pc *Connector) AuthenticateUser(username string, password string) (authent
// isValidRole - does role exist
func isValidRole(role string) error {
	switch role {
	case RoleUser, RoleSuper:
	case RoleUser, RoleAdmin:
		return nil
	}
	return errors.New("Inalid role")
+55 −30
Original line number Diff line number Diff line
@@ -30,6 +30,10 @@ const (
	pcDBHost = "localhost"
	pcDBPort = "30432"

	provider1 = ""
	provider2 = "provider2"
	provider3 = "provider3"

	username0 = ""
	username1 = "user1"
	username2 = "user2"
@@ -43,7 +47,7 @@ const (
	role0 = "invalid-role"
	role1 = "user"
	role2 = "user"
	role3 = "super"
	role3 = "admin"

	sboxname0 = "123456789012345" // more than 11 chars
	sboxname1 = "sbox-1"
@@ -132,101 +136,101 @@ func TestPostgisCreateUser(t *testing.T) {
	}

	fmt.Println("Create Invalid users")
	err = pc.CreateUser(username0, password1, role1, sboxname1)
	err = pc.CreateUser(provider1, username0, password1, role1, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password0, role1, sboxname1)
	err = pc.CreateUser(provider1, username1, password0, role1, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password1, role0, sboxname1)
	err = pc.CreateUser(provider1, username1, password1, role0, sboxname1)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}
	err = pc.CreateUser(username1, password1, role1, sboxname0)
	err = pc.CreateUser(provider1, username1, password1, role1, sboxname0)
	if err == nil {
		t.Fatalf("user creation should have failed")
	}

	fmt.Println("user DB operations")
	err = pc.CreateUser(username1, password1, role1, sboxname1)
	err = pc.CreateUser(provider1, username1, password1, role1, sboxname1)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err := pc.GetUser(username1)
	user, err := pc.GetUser(provider1, username1)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.Username != username1 || user.Role != role1 || user.Sboxname != sboxname1 {
	if user.Provider != ProviderLocal || user.Username != username1 || user.Role != role1 || user.Sboxname != sboxname1 {
		t.Fatalf("Wrong user data")
	}
	if user.Password == password1 {
		t.Fatalf("Password not encrypted")
	}
	valid, err := pc.IsValidUser(username1)
	valid, err := pc.IsValidUser(provider1, username1)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
	valid, err = pc.AuthenticateUser(username1, password1)
	valid, err = pc.AuthenticateUser(provider1, username1, password1)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
	valid, err = pc.AuthenticateUser(username1, password2)
	valid, err = pc.AuthenticateUser(provider1, username1, password2)
	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	err = pc.CreateUser(username2, password2, role2, sboxname2)
	err = pc.CreateUser(provider2, username2, password2, role2, sboxname2)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err = pc.GetUser(username2)
	user, err = pc.GetUser(provider2, username2)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.Username != username2 || user.Role != role2 || user.Sboxname != sboxname2 {
	if user.Provider != provider2 || user.Username != username2 || user.Role != role2 || user.Sboxname != sboxname2 {
		t.Fatalf("Wrong user data")
	}
	if user.Password == password2 {
		t.Fatalf("Password not encrypted")
	}
	valid, err = pc.IsValidUser(username2)
	valid, err = pc.IsValidUser(provider2, username2)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
	valid, err = pc.AuthenticateUser(username2, password2)
	valid, err = pc.AuthenticateUser(provider2, username2, password2)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
	valid, err = pc.AuthenticateUser(username2, password1)
	valid, err = pc.AuthenticateUser(provider2, username2, password1)
	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}

	err = pc.CreateUser(username3, password3, role3, sboxname3)
	err = pc.CreateUser(provider3, username3, password3, role3, sboxname3)
	if err != nil {
		t.Fatalf("Failed to create asset")
	}
	user, err = pc.GetUser(username3)
	user, err = pc.GetUser(provider3, username3)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.Username != username3 || user.Role != role3 || user.Sboxname != sboxname3 {
	if user.Provider != provider3 || user.Username != username3 || user.Role != role3 || user.Sboxname != sboxname3 {
		t.Fatalf("Wrong user data")
	}
	if user.Password == password3 {
		t.Fatalf("Password not encrypted")
	}
	valid, err = pc.IsValidUser(username3)
	valid, err = pc.IsValidUser(provider3, username3)
	if err != nil || !valid {
		t.Fatalf("Failed to validate user")
	}
	valid, err = pc.AuthenticateUser(username3, password3)
	valid, err = pc.AuthenticateUser(provider3, username3, password3)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
	valid, err = pc.AuthenticateUser(username3, password2)
	valid, err = pc.AuthenticateUser(provider3, username3, password2)
	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}
@@ -236,36 +240,57 @@ func TestPostgisCreateUser(t *testing.T) {
	if err != nil || len(userMap) != 3 {
		t.Fatalf("Error getting all users")
	}
	user, found := userMap[pc.GetUserKey(provider1, username1)]
	if !found {
		t.Fatalf("User not found")
	}
	if user.Provider != ProviderLocal || user.Username != username1 || user.Role != role1 || user.Sboxname != sboxname1 {
		t.Fatalf("Wrong user data")
	}
	user, found = userMap[pc.GetUserKey(provider2, username2)]
	if !found {
		t.Fatalf("User not found")
	}
	if user.Provider != provider2 || user.Username != username2 || user.Role != role2 || user.Sboxname != sboxname2 {
		t.Fatalf("Wrong user data")
	}
	user, found = userMap[pc.GetUserKey(provider3, username3)]
	if !found {
		t.Fatalf("User not found")
	}
	if user.Provider != provider3 || user.Username != username3 || user.Role != role3 || user.Sboxname != sboxname3 {
		t.Fatalf("Wrong user data")
	}

	// Remove & validate update
	fmt.Println("Remove user & validate update")
	err = pc.DeleteUser(username3)
	err = pc.DeleteUser(provider3, username3)
	if err != nil {
		t.Fatalf("Failed to delete user")
	}
	user, err = pc.GetUser(username3)
	user, err = pc.GetUser(provider3, username3)
	if err == nil || user != nil {
		t.Fatalf("user should no longer exist")
	}

	// Update & validate update
	fmt.Println("Add user & validate update")
	err = pc.UpdateUser(username1, password3, role3, sboxname3)
	err = pc.UpdateUser(provider1, username1, password3, role3, sboxname3)
	if err != nil {
		t.Fatalf("Failed to update asset")
	}
	user, err = pc.GetUser(username1)
	user, err = pc.GetUser(provider1, username1)
	if err != nil || user == nil {
		t.Fatalf("Failed to get user")
	}
	if user.Username != username1 || user.Role != role3 || user.Sboxname != sboxname3 {
	if user.Provider != ProviderLocal || user.Username != username1 || user.Role != role3 || user.Sboxname != sboxname3 {
		t.Fatalf("Wrong user data")
	}
	valid, err = pc.AuthenticateUser(username1, password3)
	valid, err = pc.AuthenticateUser(provider1, username1, password3)
	if err != nil || !valid {
		t.Fatalf("Failed to authenticate user")
	}
	valid, err = pc.AuthenticateUser(username1, password1)
	valid, err = pc.AuthenticateUser(provider1, username1, password1)
	if err != nil || valid {
		t.Fatalf("Wrong user authentication")
	}