From 6b080f3821d8a503aeba175b44fa483ba7514846 Mon Sep 17 00:00:00 2001 From: Pierre Zemb Date: Tue, 14 Nov 2017 00:19:45 +0100 Subject: [PATCH] Fixes #6: add http basic auth --- README.md | 23 ++++++++++++++++++++++ auth.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 27 +++++++++++++++++++++++--- 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 auth.go diff --git a/README.md b/README.md index 185476e..f4eb154 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,29 @@ Many links should provide you with additionnal info to see my point of view: docker run -d -p 80:8043 -v path/to/website:/srv/http --name goStatic pierrezemb/gostatic ``` +### Usage + +``` +./goStatic --help +Usage of ./goStatic: + -append-header HeaderName:Value + HTTP response header, specified as HeaderName:Value that should be added to all responses. + -default-user-basic-auth string + Define the user (default "gopher") + -enable-basic-auth + Enable basic auth. By default, password are randomly generated. Use --set-basic-auth to set it. + -password-length int + Size of the randomized password (default 16) + -path string + The path for the static files (default "/srv/http") + -port int + The listening port (default 8043) + -set-basic-auth string + Define the basic auth. Form must be user:password +``` + + + ### Wow, such container! What are you using? I'm using the centurylink/ca-certs image instead of the scratch image to avoid this error: diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..88fd9cb --- /dev/null +++ b/auth.go @@ -0,0 +1,58 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "log" + "net/http" + "strings" +) + +// authMiddleware checks basic auth +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + auth := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + + if len(auth) != 2 || auth[0] != "Basic" { + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + } + + payload, _ := base64.StdEncoding.DecodeString(auth[1]) + pair := strings.SplitN(string(payload), ":", 2) + + if strings.Compare(pair[0], username) != 0 || strings.Compare(pair[1], password) != 0 { + http.Error(w, "authorization failed", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +func parseAuth(auth string) { + identity := strings.Split(*setBasicAuth, ":") + if len(identity) != 2 { + log.Fatalln("basic auth must be like this: user:password") + } + + username = identity[0] + password = identity[1] +} + +func generateRandomAuth() { + username = "gopher" + password = generateRandomString() + log.Printf("User generated for basic auth. User:'%v', password:'%v'\n", username, password) +} + +func generateRandomString() string { + + b := make([]byte, *sizeRandom) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return fmt.Sprintf("%X", b) +} diff --git a/main.go b/main.go index 9d91148..a1017cc 100644 --- a/main.go +++ b/main.go @@ -13,9 +13,16 @@ import ( var ( // Def of flags - portPtr = flag.Int("p", 8043, "The listening port") - path = flag.String("static", "/srv/http", "The path for the static files") - headerFlag = flag.String("appendHeader", "", "HTTP response header, specified as `HeaderName:Value` that should be added to all responses.") + portPtr = flag.Int("port", 8043, "The listening port") + path = flag.String("path", "/srv/http", "The path for the static files") + headerFlag = flag.String("append-header", "", "HTTP response header, specified as `HeaderName:Value` that should be added to all responses.") + basicAuth = flag.Bool("enable-basic-auth", false, "Enable basic auth. By default, password are randomly generated. Use --set-basic-auth to set it.") + setBasicAuth = flag.String("set-basic-auth", "", "Define the basic auth. Form must be user:password") + defaultUsernameBasicAuth = flag.String("default-user-basic-auth", "gopher", "Define the user") + sizeRandom = flag.Int("password-length", 16, "Size of the randomized password") + + username string + password string ) func parseHeaderFlag(headerFlag string) (string, string) { @@ -33,10 +40,24 @@ func main() { flag.Parse() + // sanity check + if len(*setBasicAuth) != 0 && !*basicAuth { + *basicAuth = true + } + port := ":" + strconv.FormatInt(int64(*portPtr), 10) handler := http.FileServer(http.Dir(*path)) + if *basicAuth { + if len(*setBasicAuth) != 0 { + parseAuth(*setBasicAuth) + } else { + generateRandomAuth() + } + handler = authMiddleware(handler) + } + // Extra headers. if len(*headerFlag) > 0 { header, headerValue := parseHeaderFlag(*headerFlag)