Unverified Commit c262a87c authored by Kevin Di Lallo's avatar Kevin Di Lallo Committed by GitHub
Browse files

Merge pull request #150 from dilallkx/kd_sp45_dev_oauth

OAuth Improvements
parents ce5d4ad4 ae282d2a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

version: 1.5.6
version: 1.5.7
repo:
  name: AdvantEDGE

+2 −2
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ type Scenario struct {

type LoginRequest struct {
	provider string
	timestamp time.Time
	timer    *time.Timer
}

type PlatformCtrl struct {
+81 −25
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ import (
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	dataModel "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-data-model"
@@ -50,6 +51,9 @@ import (

const OAUTH_PROVIDER_GITHUB = "github"
const OAUTH_PROVIDER_GITLAB = "gitlab"
const OAUTH_PROVIDER_LOCAL = "local"

var mutex sync.Mutex

func initOAuth() {

@@ -113,6 +117,39 @@ func getUniqueState() (state string, err error) {
	return "", errors.New("Failed to generate a random state string")
}

func getLoginRequest(state string) *LoginRequest {
	mutex.Lock()
	defer mutex.Unlock()
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		return nil
	}
	return request
}

func setLoginRequest(state string, request *LoginRequest) {
	mutex.Lock()
	defer mutex.Unlock()
	pfmCtrl.loginRequests[state] = request
}

func delLoginRequest(state string) {
	mutex.Lock()
	defer mutex.Unlock()
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		return
	}
	if request.timer != nil {
		request.timer.Stop()
	}
	delete(pfmCtrl.loginRequests, state)
}

func getErrUrl(err string) string {
	return pfmCtrl.uri + "?err=" + strings.ReplaceAll(err, " ", "+")
}

func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	log.Info("----- OAUTH LOGIN -----")

@@ -123,8 +160,9 @@ func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	// Get provider-specific OAuth config
	config, found := pfmCtrl.oauthConfigs[provider]
	if !found {
		log.Error("Provider config not found for: ", provider)
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		err := errors.New("Provider config not found for: " + provider)
		log.Error(err.Error())
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

@@ -132,16 +170,22 @@ func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	state, err := getUniqueState()
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

	// Track oauth request
	// Track oauth request & handle
	request := &LoginRequest{
		provider: provider,
		timestamp: time.Now(),
		timer:    time.NewTimer(10 * time.Minute),
	}
	pfmCtrl.loginRequests[state] = request
	setLoginRequest(state, request)

	// Start timer to remove request from map
	go func() {
		<-request.timer.C
		delLoginRequest(state)
	}()

	// Generate provider-specific oauth redirect
	uri := config.AuthCodeURL(state, oauth2.AccessTypeOnline)
@@ -156,40 +200,51 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	state := query.Get("state")

	// Validate request state
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		log.Error("Login request not found with provided state: ", state)
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
	request := getLoginRequest(state)
	if request == nil {
		err := errors.New("Invalid OAuth state")
		log.Error(err.Error())
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

	// Get provider-specific OAuth config
	config := pfmCtrl.oauthConfigs[request.provider]
	provider := request.provider
	config := pfmCtrl.oauthConfigs[provider]

	// Delete login request & timer
	delLoginRequest(state)

	// Retrieve access token
	token, err := config.Exchange(context.Background(), code)
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

	// Retrieve User ID
	var userId string
	switch request.provider {
	switch provider {
	case OAUTH_PROVIDER_GITHUB:
		oauthClient := config.Client(context.Background(), token)
		if oauthClient == nil {
			err = errors.New("Failed to create new GitHub oauth client")
			log.Error(err.Error())
			http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
			return
		}
		client := github.NewClient(oauthClient)
		if client == nil {
			err = errors.New("Failed to create new GitHub client")
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
			return
		}
		user, _, err := client.Users.Get(context.Background(), "")
		if err != nil {
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl("Failed to retrieve GitHub user ID"), http.StatusFound)
			return
		}
		userId = *user.Login
@@ -198,13 +253,13 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
		if err != nil {
			err = errors.New("Failed to create new GitLab client")
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
			return
		}
		user, _, err := client.Users.CurrentUser()
		if err != nil {
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl("Failed to retrieve GitLab user ID"), http.StatusFound)
			return
		}
		userId = user.Username
@@ -212,10 +267,10 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	}

	// Start user session
	sandboxName, err, errCode := startSession(userId, w, r)
	sandboxName, err, errCode := startSession(provider, userId, w, r)
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, errCode)
		http.Redirect(w, r, getErrUrl(err.Error()), errCode)
		return
	}

@@ -231,14 +286,14 @@ func uaLoginUser(w http.ResponseWriter, r *http.Request) {
	password := r.FormValue("password")

	// Validate user credentials
	authenticated, err := pfmCtrl.userStore.AuthenticateUser(username, password)
	authenticated, err := pfmCtrl.userStore.AuthenticateUser(OAUTH_PROVIDER_LOCAL, username, password)
	if err != nil || !authenticated {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	// Start user session
	sandboxName, err, errCode := startSession(username, w, r)
	sandboxName, err, errCode := startSession(OAUTH_PROVIDER_LOCAL, username, w, r)
	if err != nil {
		log.Error(err.Error())
		http.Error(w, err.Error(), errCode)
@@ -264,11 +319,11 @@ func uaLoginUser(w http.ResponseWriter, r *http.Request) {
}

// Retrieve existing user session or create a new one
func startSession(username string, w http.ResponseWriter, r *http.Request) (sandboxName string, err error, code int) {
func startSession(provider string, username string, w http.ResponseWriter, r *http.Request) (sandboxName string, err error, code int) {

	// Get existing session by user name, if any
	sessionStore := pfmCtrl.sessionMgr.GetSessionStore()
	session, err := sessionStore.GetByName(username)
	session, err := sessionStore.GetByName(provider, username)
	if err != nil {
		// Check if max session count is reached before creating a new one
		count := sessionStore.GetCount()
@@ -279,7 +334,7 @@ func startSession(username string, w http.ResponseWriter, r *http.Request) (sand

		// Get requested sandbox name & role from user profile, if any
		role := users.RoleUser
		user, err := pfmCtrl.userStore.GetUser(username)
		user, err := pfmCtrl.userStore.GetUser(provider, username)
		if err == nil {
			sandboxName = user.Sboxname
			role = user.Role
@@ -305,6 +360,7 @@ func startSession(username string, w http.ResponseWriter, r *http.Request) (sand
		session = new(sm.Session)
		session.ID = ""
		session.Username = username
		session.Provider = provider
		session.Sandbox = sandboxName
		session.Role = role
	} else {
+0 −19
Original line number Diff line number Diff line
{
  "requires": true,
  "lockfileVersion": 1,
  "dependencies": {
    "get-port": {
      "version": "4.2.0",
      "resolved": "https://registry.npmjs.org/get-port/-/get-port-4.2.0.tgz",
      "integrity": "sha512-/b3jarXkH8KJoOMQc3uVGHASwGLPq3gSFJ7tgJm2diza+bydJPTGOibin2steecKeOylE8oY2JERlVWkAJO6yw=="
    },
    "http-echo-server": {
      "version": "2.1.1",
      "resolved": "https://registry.npmjs.org/http-echo-server/-/http-echo-server-2.1.1.tgz",
      "integrity": "sha512-ybEQrtw0fGmSHZHa8W0tjHFz+m9ZBxWT2aYGWTaAlU2fldrxbjOsgs8qfbXBLwJMsToQIg89mTS9kbw44ZRr8A==",
      "requires": {
        "get-port": "^4.0.0"
      }
    }
  }
}
+0 −19
Original line number Diff line number Diff line
{
  "requires": true,
  "lockfileVersion": 1,
  "dependencies": {
    "get-port": {
      "version": "4.2.0",
      "resolved": "https://registry.npmjs.org/get-port/-/get-port-4.2.0.tgz",
      "integrity": "sha512-/b3jarXkH8KJoOMQc3uVGHASwGLPq3gSFJ7tgJm2diza+bydJPTGOibin2steecKeOylE8oY2JERlVWkAJO6yw=="
    },
    "http-echo-server": {
      "version": "2.1.1",
      "resolved": "https://registry.npmjs.org/http-echo-server/-/http-echo-server-2.1.1.tgz",
      "integrity": "sha512-ybEQrtw0fGmSHZHa8W0tjHFz+m9ZBxWT2aYGWTaAlU2fldrxbjOsgs8qfbXBLwJMsToQIg89mTS9kbw44ZRr8A==",
      "requires": {
        "get-port": "^4.0.0"
      }
    }
  }
}
Loading