From ae26700e590ed1462ec00841672b797f00fc6120 Mon Sep 17 00:00:00 2001 From: Mark Bailey Date: Mon, 25 Nov 2024 14:48:26 -0500 Subject: [PATCH] fix: fix routing and dependency injection --- app/controller/controller.go | 20 ++++++++++++++------ app/handler/handler.go | 8 +++++--- app/routing/router.go | 29 ++++++++++++++--------------- app/routing/routes.go | 2 +- go.mod | 2 +- handlers/admin/index.go | 4 ++-- handlers/shared/homepage.go | 14 +++++++------- handlers/shared/user.go | 27 ++++++++++++++++----------- lib/middleware/middleware.go | 12 ++++++++---- view/user/signin.templ | 4 +--- view/user/signin_templ.go | 4 ++-- 11 files changed, 71 insertions(+), 55 deletions(-) diff --git a/app/controller/controller.go b/app/controller/controller.go index 04467b6..0ab4550 100644 --- a/app/controller/controller.go +++ b/app/controller/controller.go @@ -1,12 +1,12 @@ package controller import ( - "git.markbailey.dev/cerbervs/ptpp/app/session" - "git.markbailey.dev/cerbervs/ptpp/lib/logger" "net/http" + "git.markbailey.dev/cerbervs/ptpp/app/session" + "git.markbailey.dev/cerbervs/ptpp/lib/logger" + "git.markbailey.dev/cerbervs/ptpp/lib/database" - werror "git.markbailey.dev/cerbervs/ptpp/lib/error" ) type IController interface { @@ -19,7 +19,7 @@ type IController interface { Post(w http.ResponseWriter, r *http.Request) error Patch(w http.ResponseWriter, r *http.Request) error Connect(w http.ResponseWriter, r *http.Request) error - Init(session.IManager, database.IDB, logger.ILogger) error + Init(session.IManager, database.IDB, logger.ILogger) IController } type Controller struct { @@ -28,8 +28,8 @@ type Controller struct { Logger logger.ILogger } -func (c Controller) Init(s session.IManager, d database.IDB, l logger.ILogger) error { - return werror.Wrap(nil, "You must implement the init method in your extended controller") +func (c Controller) Init(s session.IManager, d database.IDB, l logger.ILogger) IController { + return nil } func (c Controller) Get(w http.ResponseWriter, _ *http.Request) (_ error) { @@ -39,6 +39,7 @@ func (c Controller) Get(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Head(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -46,6 +47,7 @@ func (c Controller) Head(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Options(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -53,6 +55,7 @@ func (c Controller) Options(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Trace(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -60,6 +63,7 @@ func (c Controller) Trace(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Put(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -67,6 +71,7 @@ func (c Controller) Put(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Delete(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -74,6 +79,7 @@ func (c Controller) Delete(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Post(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -81,6 +87,7 @@ func (c Controller) Post(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Patch(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) @@ -88,6 +95,7 @@ func (c Controller) Patch(w http.ResponseWriter, _ *http.Request) (_ error) { } return nil } + func (c Controller) Connect(w http.ResponseWriter, _ *http.Request) (_ error) { if _, err := w.Write([]byte("not implemented")); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error writing response")) diff --git a/app/handler/handler.go b/app/handler/handler.go index 61a1f96..016ec0e 100644 --- a/app/handler/handler.go +++ b/app/handler/handler.go @@ -1,10 +1,12 @@ package handler -import "net/http" +import ( + "net/http" +) -type Handler func(http.ResponseWriter, *http.Request) error +type HandlerFunc func(http.ResponseWriter, *http.Request) error -func (fn Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (fn HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := fn(w, r); err != nil { w.Header().Set("HX-Retarget", "#layout_content") w.Header().Set("HX-Reswap", "innerHTML") diff --git a/app/routing/router.go b/app/routing/router.go index 325b5a7..132b614 100644 --- a/app/routing/router.go +++ b/app/routing/router.go @@ -1,6 +1,8 @@ package routing import ( + "net/http" + "git.markbailey.dev/cerbervs/ptpp/app/controller" "git.markbailey.dev/cerbervs/ptpp/app/handler" "git.markbailey.dev/cerbervs/ptpp/app/session" @@ -9,7 +11,6 @@ import ( "git.markbailey.dev/cerbervs/ptpp/lib/logger" "git.markbailey.dev/cerbervs/ptpp/lib/middleware" "git.markbailey.dev/cerbervs/ptpp/util" - "net/http" ) var sess session.IManager @@ -26,21 +27,22 @@ type Route struct { type Router struct { Mux *http.ServeMux - BasePath string - Routes []Route SubRouters *[]Router Middleware *[]middleware.Func + BasePath string + Routes []Route } func (r Router) HandleAllRequestMethods(route Route) { - r.Mux.Handle("GET "+r.BasePath+route.Path, handler.Handler(route.Controller.Get)) - r.Mux.Handle("OPTIONS "+r.BasePath+route.Path, handler.Handler(route.Controller.Options)) - r.Mux.Handle("TRACE "+r.BasePath+route.Path, handler.Handler(route.Controller.Trace)) - r.Mux.Handle("PUT "+r.BasePath+route.Path, handler.Handler(route.Controller.Put)) - r.Mux.Handle("DELETE "+r.BasePath+route.Path, handler.Handler(route.Controller.Delete)) - r.Mux.Handle("POST "+r.BasePath+route.Path, handler.Handler(route.Controller.Post)) - r.Mux.Handle("PATCH "+r.BasePath+route.Path, handler.Handler(route.Controller.Patch)) - r.Mux.Handle("CONNECT "+r.BasePath+route.Path, handler.Handler(route.Controller.Connect)) + c := route.Controller.Init(sess, database.ChooseDB(), logger.NewCompositeLogger()) + r.Mux.Handle("GET "+r.BasePath+route.Path, handler.HandlerFunc(c.Get)) + r.Mux.Handle("OPTIONS "+r.BasePath+route.Path, handler.HandlerFunc(c.Options)) + r.Mux.Handle("TRACE "+r.BasePath+route.Path, handler.HandlerFunc(c.Trace)) + r.Mux.Handle("PUT "+r.BasePath+route.Path, handler.HandlerFunc(c.Put)) + r.Mux.Handle("DELETE "+r.BasePath+route.Path, handler.HandlerFunc(c.Delete)) + r.Mux.Handle("POST "+r.BasePath+route.Path, handler.HandlerFunc(c.Post)) + r.Mux.Handle("PATCH "+r.BasePath+route.Path, handler.HandlerFunc(c.Patch)) + r.Mux.Handle("CONNECT "+r.BasePath+route.Path, handler.HandlerFunc(c.Connect)) } func (r Router) RegisterRoutes() http.Handler { @@ -49,10 +51,6 @@ func (r Router) RegisterRoutes() http.Handler { } for _, route := range r.Routes { - if err := route.Controller.Init(sess, database.ChooseDB(), logger.NewCompositeLogger()); err != nil { - panic(err) - } - r.HandleAllRequestMethods(route) } @@ -108,6 +106,7 @@ func (r Router) GetRouteByName(name string) (string, error) { return "", werror.Wrap(nil, "Route not found") } + func init() { var err error diff --git a/app/routing/routes.go b/app/routing/routes.go index 9481137..a6ab7f1 100644 --- a/app/routing/routes.go +++ b/app/routing/routes.go @@ -11,7 +11,7 @@ var AppRouter = Router{ Mux: http.NewServeMux(), BasePath: "/", Routes: []Route{ - {Controller: &shared.HomePageHandler{}, Path: "", Name: "app.index"}, + {Controller: &shared.HomePageController{}, Path: "", Name: "app.index"}, {Controller: &shared.SignUpHandler{}, Path: "sign-up", Name: "app.user.sign_up"}, {Controller: &shared.SignInHandler{}, Path: "sign-in", Name: "app.user.sign_in"}, {Controller: &shared.SignOutHandler{}, Path: "sign-out", Name: "app.user.sign_out"}, diff --git a/go.mod b/go.mod index eef1f19..cb5019d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.markbailey.dev/cerbervs/ptpp -go 1.23.2 +go 1.23 require ( github.com/a-h/templ v0.2.793 diff --git a/handlers/admin/index.go b/handlers/admin/index.go index 002aa48..eb0015c 100644 --- a/handlers/admin/index.go +++ b/handlers/admin/index.go @@ -17,12 +17,12 @@ type IndexHandler struct { controller.Controller } -func (h IndexHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { +func (h IndexHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { h.Logger = l h.Db = d h.Session = s - return nil + return h } func (h IndexHandler) Get(w http.ResponseWriter, r *http.Request) error { diff --git a/handlers/shared/homepage.go b/handlers/shared/homepage.go index f8d7b6f..5aa9a52 100644 --- a/handlers/shared/homepage.go +++ b/handlers/shared/homepage.go @@ -13,19 +13,19 @@ import ( "git.markbailey.dev/cerbervs/ptpp/view/homepage" ) -type HomePageHandler struct { +type HomePageController struct { controller.Controller } -func (h HomePageHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { - h.Logger = l - h.Db = d - h.Session = s +func (c HomePageController) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { + c.Logger = l + c.Db = d + c.Session = s - return nil + return c } -func (h HomePageHandler) Get(w http.ResponseWriter, r *http.Request) error { +func (h HomePageController) Get(w http.ResponseWriter, r *http.Request) error { if r.URL.Path != "/" { w.WriteHeader(http.StatusNotFound) err := layout.NotFound().Render(context.Background(), w) diff --git a/handlers/shared/user.go b/handlers/shared/user.go index 37ccc32..a6d178c 100644 --- a/handlers/shared/user.go +++ b/handlers/shared/user.go @@ -3,14 +3,15 @@ package shared import ( "context" "encoding/json" + "net/http" + "time" + "git.markbailey.dev/cerbervs/ptpp/app/controller" "git.markbailey.dev/cerbervs/ptpp/app/session" "git.markbailey.dev/cerbervs/ptpp/lib/database" "git.markbailey.dev/cerbervs/ptpp/lib/database/dto" "git.markbailey.dev/cerbervs/ptpp/lib/logger" "git.markbailey.dev/cerbervs/ptpp/util" - "net/http" - "time" "git.markbailey.dev/cerbervs/ptpp/view/user" ) @@ -19,48 +20,48 @@ type SignUpHandler struct { controller.Controller } -func (c SignUpHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { +func (c SignUpHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { c.Logger = l c.Db = d c.Session = s - return nil + return c } type SignInHandler struct { controller.Controller } -func (c SignInHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { +func (c SignInHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { c.Logger = l c.Db = d c.Session = s - return nil + return c } type PopulateHandler struct { controller.Controller } -func (c PopulateHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { +func (c PopulateHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { c.Logger = l c.Db = d c.Session = s - return nil + return c } type SignOutHandler struct { controller.Controller } -func (c SignOutHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) error { +func (c SignOutHandler) Init(s session.IManager, d database.IDB, l logger.ILogger) controller.IController { c.Logger = l c.Db = d c.Session = s - return nil + return c } func (c PopulateHandler) Get(w http.ResponseWriter, r *http.Request) error { @@ -144,7 +145,7 @@ func (c SignInHandler) Post(w http.ResponseWriter, r *http.Request) error { foundUser, err := c.Db.Repo().FindUserByUsername(fd.Username) if foundUser == nil || err != nil { - return util.Redirect(w, r, "/sign-in", http.StatusSeeOther, false) + return util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, false) } authenticated, err := util.CheckPassword(fd.Password, foundUser.Password) @@ -250,6 +251,10 @@ func (c SignUpHandler) Post(w http.ResponseWriter, r *http.Request) error { } foundUser, err := c.Db.Repo().FindUserByUsername(fd.Username) + if err != nil { + c.Db.Error(err) + return err + } if foundUser != nil { if _, err := w.Write([]byte("Invalid username. Please try another")); err != nil { diff --git a/lib/middleware/middleware.go b/lib/middleware/middleware.go index c210e87..7733cbd 100644 --- a/lib/middleware/middleware.go +++ b/lib/middleware/middleware.go @@ -75,15 +75,15 @@ func WithAuth(next http.Handler) http.Handler { } if cookie, err = r.Cookie("token"); err != nil { - _ = util.Redirect(w, r, "/signin", http.StatusSeeOther, true) + _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) return } if token = cookie.Value; token == "" { - _ = util.Redirect(w, r, "/signin", http.StatusSeeOther, true) + _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) return } if claims, err = util.ParseToken(token, os.Getenv("TOKEN_SECRET")); err != nil { - _ = util.Redirect(w, r, "/signin", http.StatusSeeOther, true) + _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) return } @@ -136,5 +136,9 @@ func WithUsername(next http.Handler) http.Handler { } func init() { - sess, _ = session.NewManager("memory", "ptpp", 3600) + var err error + sess, err = session.NewManager("memory", "ptpp", 3600) + if err != nil { + panic(err) + } } diff --git a/view/user/signin.templ b/view/user/signin.templ index bf3be38..389388f 100644 --- a/view/user/signin.templ +++ b/view/user/signin.templ @@ -15,9 +15,7 @@ templ SignIn(err string) {
diff --git a/view/user/signin_templ.go b/view/user/signin_templ.go index ea73fc6..5aeaef9 100644 --- a/view/user/signin_templ.go +++ b/view/user/signin_templ.go @@ -43,7 +43,7 @@ func SignIn(err string) templ.Component { }() } ctx = templ.InitializeContext(ctx) - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("

Create an account

") + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("

Create an account

") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -55,7 +55,7 @@ func SignIn(err string) templ.Component { var templ_7745c5c3_Var3 string templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(err) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `view/user/signin.templ`, Line: 54, Col: 14} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `view/user/signin.templ`, Line: 52, Col: 14} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3)) if templ_7745c5c3_Err != nil {