Commit 92fff640 authored by Kevin Di Lallo's avatar Kevin Di Lallo
Browse files

oauth login request cleanup + oauth error reporting in redirect url

parent ce5d4ad4
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ type Scenario struct {

type LoginRequest struct {
	provider string
	timestamp time.Time
	timer    *time.Timer
}

type PlatformCtrl struct {
+64 −17
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ import (
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	dataModel "github.com/InterDigitalInc/AdvantEDGE/go-packages/meep-data-model"
@@ -51,6 +52,8 @@ import (
const OAUTH_PROVIDER_GITHUB = "github"
const OAUTH_PROVIDER_GITLAB = "gitlab"

var mutex sync.Mutex

func initOAuth() {

	// Get default platform URI
@@ -113,6 +116,39 @@ func getUniqueState() (state string, err error) {
	return "", errors.New("Failed to generate a random state string")
}

func getLoginRequest(state string) *LoginRequest {
	mutex.Lock()
	defer mutex.Unlock()
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		return nil
	}
	return request
}

func setLoginRequest(state string, request *LoginRequest) {
	mutex.Lock()
	defer mutex.Unlock()
	pfmCtrl.loginRequests[state] = request
}

func delLoginRequest(state string) {
	mutex.Lock()
	defer mutex.Unlock()
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		return
	}
	if request.timer != nil {
		request.timer.Stop()
	}
	delete(pfmCtrl.loginRequests, state)
}

func getErrUrl(err string) string {
	return pfmCtrl.uri + "?err=" + strings.ReplaceAll(err, " ", "_")
}

func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	log.Info("----- OAUTH LOGIN -----")

@@ -123,8 +159,9 @@ func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	// Get provider-specific OAuth config
	config, found := pfmCtrl.oauthConfigs[provider]
	if !found {
		log.Error("Provider config not found for: ", provider)
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		err := errors.New("Provider config not found for: " + provider)
		log.Error(err.Error())
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

@@ -132,16 +169,22 @@ func uaLoginOAuth(w http.ResponseWriter, r *http.Request) {
	state, err := getUniqueState()
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

	// Track oauth request
	// Track oauth request & handle
	request := &LoginRequest{
		provider: provider,
		timestamp: time.Now(),
		timer:    time.NewTimer(10 * time.Minute),
	}
	pfmCtrl.loginRequests[state] = request
	setLoginRequest(state, request)

	// Start timer to remove request from map
	go func() {
		<-request.timer.C
		delLoginRequest(state)
	}()

	// Generate provider-specific oauth redirect
	uri := config.AuthCodeURL(state, oauth2.AccessTypeOnline)
@@ -156,13 +199,17 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	state := query.Get("state")

	// Validate request state
	request, found := pfmCtrl.loginRequests[state]
	if !found {
		log.Error("Login request not found with provided state: ", state)
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
	request := getLoginRequest(state)
	if request == nil {
		err := errors.New("Invalid OAuth state")
		log.Error(err.Error())
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

	// Delete login request & timer
	delLoginRequest(state)

	// Get provider-specific OAuth config
	config := pfmCtrl.oauthConfigs[request.provider]

@@ -170,7 +217,7 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	token, err := config.Exchange(context.Background(), code)
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
		http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
		return
	}

@@ -183,13 +230,13 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
		if client == nil {
			err = errors.New("Failed to create new GitHub client")
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
			return
		}
		user, _, err := client.Users.Get(context.Background(), "")
		if err != nil {
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl("Failed to retrieve GitHub user ID"), http.StatusFound)
			return
		}
		userId = *user.Login
@@ -198,13 +245,13 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
		if err != nil {
			err = errors.New("Failed to create new GitLab client")
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl(err.Error()), http.StatusFound)
			return
		}
		user, _, err := client.Users.CurrentUser()
		if err != nil {
			log.Error(err.Error())
			http.Redirect(w, r, pfmCtrl.uri, http.StatusFound)
			http.Redirect(w, r, getErrUrl("Failed to retrieve GitLab user ID"), http.StatusFound)
			return
		}
		userId = user.Username
@@ -215,7 +262,7 @@ func uaAuthorize(w http.ResponseWriter, r *http.Request) {
	sandboxName, err, errCode := startSession(userId, w, r)
	if err != nil {
		log.Error(err.Error())
		http.Redirect(w, r, pfmCtrl.uri, errCode)
		http.Redirect(w, r, getErrUrl(err.Error()), errCode)
		return
	}