diff --git a/.air.toml b/.air.toml index 8ea14fb..b162020 100644 --- a/.air.toml +++ b/.air.toml @@ -4,8 +4,8 @@ tmp_dir = "tmp" [build] args_bin = [] - bin = "./tmp/main" - cmd = "go build -o ./tmp/main ./cmd/" + bin = "./air-bin" + cmd = "go build -o ./air-bin ./cmd/" delay = 1000 exclude_dir = ["public", "data", "tmp"] exclude_file = [] diff --git a/app/controller/ctx.go b/app/controller/ctx.go index 7d56d2c..ba8947f 100644 --- a/app/controller/ctx.go +++ b/app/controller/ctx.go @@ -1,7 +1,5 @@ package controller -import werror "git.markbailey.dev/cerbervs/ptpp/lib/error" - type IRouterCtx interface { GetRouteByName(string) (string, error) } @@ -14,10 +12,10 @@ type ControllerCtx struct { RouterCtx IRouterCtx } -func (c *ControllerCtx) GetRouteByName(name string) string { +func (c ControllerCtx) GetRouteByName(name string) string { path, err := c.RouterCtx.GetRouteByName(name) if err != nil { - panic(werror.Wrap(err, "Error getting route by name")) + panic(err) } return path diff --git a/app/routing/router.go b/app/routing/router.go index a283172..bb4acd2 100644 --- a/app/routing/router.go +++ b/app/routing/router.go @@ -1,7 +1,9 @@ package routing import ( + "errors" "net/http" + "sync" "git.markbailey.dev/cerbervs/ptpp/app/controller" "git.markbailey.dev/cerbervs/ptpp/app/handler" @@ -33,48 +35,64 @@ type Router struct { Routes []Route } -func (r Router) HandleAllRequestMethods(route Route) { - ctx := &controller.ControllerCtx{RouterCtx: r} +var ( + rtr *Router + rtrOnce sync.Once + mux *http.ServeMux + muxOnce sync.Once + fsOnce sync.Once +) + +func (r *Router) HandleAllRequestMethods(route Route, mux *http.ServeMux) { + ctx := controller.ControllerCtx{RouterCtx: rtr} c := route.Controller.Init(sess, database.ChooseDB(), logger.NewCompositeLogger(), ctx) - r.Mux.Handle("GET "+r.BasePath+route.Path, handler.HandlerFunc(c.Get)) - r.Mux.Handle("OPTIONS "+r.BasePath+route.Path, handler.HandlerFunc(c.Options)) - r.Mux.Handle("TRACE "+r.BasePath+route.Path, handler.HandlerFunc(c.Trace)) - r.Mux.Handle("PUT "+r.BasePath+route.Path, handler.HandlerFunc(c.Put)) - r.Mux.Handle("DELETE "+r.BasePath+route.Path, handler.HandlerFunc(c.Delete)) - r.Mux.Handle("POST "+r.BasePath+route.Path, handler.HandlerFunc(c.Post)) - r.Mux.Handle("PATCH "+r.BasePath+route.Path, handler.HandlerFunc(c.Patch)) - r.Mux.Handle("CONNECT "+r.BasePath+route.Path, handler.HandlerFunc(c.Connect)) + mux.Handle("GET "+r.BasePath+route.Path, handler.HandlerFunc(c.Get)) + mux.Handle("OPTIONS "+r.BasePath+route.Path, handler.HandlerFunc(c.Options)) + mux.Handle("TRACE "+r.BasePath+route.Path, handler.HandlerFunc(c.Trace)) + mux.Handle("PUT "+r.BasePath+route.Path, handler.HandlerFunc(c.Put)) + mux.Handle("DELETE "+r.BasePath+route.Path, handler.HandlerFunc(c.Delete)) + mux.Handle("POST "+r.BasePath+route.Path, handler.HandlerFunc(c.Post)) + mux.Handle("PATCH "+r.BasePath+route.Path, handler.HandlerFunc(c.Patch)) + mux.Handle("CONNECT "+r.BasePath+route.Path, handler.HandlerFunc(c.Connect)) } -func (r Router) RegisterRoutes() http.Handler { - if r.Mux == nil { - r.Mux = http.NewServeMux() - } +func (r *Router) RegisterRoutes() http.Handler { + rtrOnce.Do(func() { + rtr = r + }) + muxOnce.Do(func() { + mux = http.NewServeMux() + r.Mux = mux + }) + fsOnce.Do(func() { + r.RegisterFs() + }) - for _, route := range r.Routes { - r.HandleAllRequestMethods(route) - } + if r.Mux == nil { + r.Mux = http.NewServeMux() + } if r.SubRouters != nil { for _, subRouter := range *r.SubRouters { sr := subRouter.RegisterRoutes() - r.Mux.Handle("GET "+r.BasePath+subRouter.BasePath, sr) + r.Mux.Handle(r.BasePath+subRouter.BasePath, sr) } } + for _, route := range r.Routes { + r.HandleAllRequestMethods(route, r.Mux) + } + if r.Middleware != nil { mw := middleware.Compose(*r.Middleware) - return mw(r.Mux) + mctx := middleware.MiddlewareCtx{RouterCtx: rtr} + return mw(mux, mctx) } - return r.Mux + return mux } -func (r Router) RegisterFs() { - if r.Mux == nil { - r.Mux = http.NewServeMux() - } - +func (r *Router) RegisterFs() { fs := http.FileServer(http.Dir(util.GetFullyQualifiedPath("/public"))) r.Mux.Handle("GET "+r.BasePath+"public/", http.StripPrefix("/public/", fs)) } @@ -84,11 +102,11 @@ type RouteMapping struct { Name string } -func (r Router) GetFlatRouteList() []RouteMapping { +func (r *Router) GetFlatRouteList() []RouteMapping { var routes []RouteMapping for _, route := range r.Routes { - routes = append(routes, RouteMapping{Path: r.BasePath+route.Path, Name: route.Name}) + routes = append(routes, RouteMapping{Path: r.BasePath + route.Path, Name: route.Name}) } if r.SubRouters == nil { @@ -101,14 +119,14 @@ func (r Router) GetFlatRouteList() []RouteMapping { return routes } -func (r Router) GetRouteByName(name string) (string, error) { +func (r *Router) GetRouteByName(name string) (string, error) { for _, route := range r.GetFlatRouteList() { if route.Name == name { return route.Path, nil } } - return "", werror.Wrap(nil, "Route not found") + return "", werror.Wrap(errors.New(name+" does not exist"), "Route not found") } func init() { diff --git a/app/routing/routes.go b/app/routing/routes.go index a6ab7f1..c6686ab 100644 --- a/app/routing/routes.go +++ b/app/routing/routes.go @@ -1,33 +1,45 @@ package routing import ( + "sync" + "git.markbailey.dev/cerbervs/ptpp/handlers/admin" "git.markbailey.dev/cerbervs/ptpp/handlers/shared" "git.markbailey.dev/cerbervs/ptpp/lib/middleware" - "net/http" ) -var AppRouter = Router{ - Mux: http.NewServeMux(), - BasePath: "/", - Routes: []Route{ - {Controller: &shared.HomePageController{}, Path: "", Name: "app.index"}, - {Controller: &shared.SignUpHandler{}, Path: "sign-up", Name: "app.user.sign_up"}, - {Controller: &shared.SignInHandler{}, Path: "sign-in", Name: "app.user.sign_in"}, - {Controller: &shared.SignOutHandler{}, Path: "sign-out", Name: "app.user.sign_out"}, - {Controller: &shared.PopulateHandler{}, Path: "populate", Name: "app.populate"}, - }, - SubRouters: &[]Router{ - { +var ( + rtrinst *Router + rtrinstOnce sync.Once +) + +func NewRouter() *Router { + rtrinstOnce.Do(func() { + rtrinst = &Router{ Mux: nil, - BasePath: "admin/", + BasePath: "/", Routes: []Route{ - {Controller: &admin.IndexHandler{}, Path: "", Name: "app.admin.index"}, - {Controller: &admin.IndexHandler{}, Path: "butt", Name: "app.admin.butt"}, + {Controller: shared.HomePageController{}, Path: "", Name: "app.index"}, + {Controller: shared.SignUpHandler{}, Path: "sign-up", Name: "app.user.sign_up"}, + {Controller: shared.SignInHandler{}, Path: "sign-in", Name: "app.user.sign_in"}, + {Controller: shared.SignOutHandler{}, Path: "sign-out", Name: "app.user.sign_out"}, + {Controller: shared.PopulateHandler{}, Path: "populate", Name: "app.populate"}, }, - SubRouters: nil, - Middleware: &[]middleware.Func{middleware.WithAuth}, - }, - }, - Middleware: &[]middleware.Func{middleware.WithLogger, middleware.WithUsername}, + SubRouters: &[]Router{ + { + Mux: nil, + BasePath: "/admin/", + Routes: []Route{ + {Controller: admin.IndexHandler{}, Path: "", Name: "app.admin.index"}, + {Controller: admin.IndexHandler{}, Path: "butt", Name: "app.admin.butt"}, + }, + SubRouters: nil, + Middleware: &[]middleware.Func{middleware.WithAuth}, + }, + }, + Middleware: &[]middleware.Func{middleware.DontPanic, middleware.WithLogger, middleware.WithUsername}, + } + }) + + return rtrinst } diff --git a/bin/app b/bin/app index f8e80b6..7363c92 100755 --- a/bin/app +++ b/bin/app @@ -290,6 +290,7 @@ __sqlc() { echo ================================================================================ echo = Generating SQLC ============================================================== echo ================================================================================ + rm -rf ./models/* sqlc generate echo -e "\n" } @@ -298,6 +299,7 @@ __templ() { echo ================================================================================ echo = Generating templates ========================================================= echo ================================================================================ + find . -name '*_templ.go' -delete templ generate echo -e "\n" } diff --git a/cmd/main.go b/cmd/main.go index f7be0d7..8d0b026 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -19,10 +19,8 @@ func main() { devPort = 8080 ) - r := routing.AppRouter - - rh := r.RegisterRoutes() - r.RegisterFs() + r := routing.NewRouter() + mux := r.RegisterRoutes() var port int if os.Getenv("HTMX_APP_ENV") == "production" { @@ -35,7 +33,7 @@ func main() { Addr: addr, Server: http.Server{ Addr: addr + ":" + strconv.Itoa(port), - Handler: rh, + Handler: mux, DisableGeneralOptionsHandler: false, TLSConfig: nil, ReadTimeout: 5 * time.Second, diff --git a/handlers/shared/user.go b/handlers/shared/user.go index 7baa940..c0b1698 100644 --- a/handlers/shared/user.go +++ b/handlers/shared/user.go @@ -66,6 +66,7 @@ func (c SignInHandler) Init(s session.IManager, d database.IDB, l logger.ILogger c.Logger = l c.Db = d c.Session = s + c.Ctx = ctx return c } @@ -89,9 +90,9 @@ func (c SignInHandler) Get(w http.ResponseWriter, r *http.Request) error { } if foundUser.Admin == 1 { - return util.Redirect(w, r, "/admin/", http.StatusSeeOther, false) + return util.Redirect(w, r, c.Ctx.GetRouteByName("app.admin.index"), http.StatusSeeOther, false) } else { - return util.Redirect(w, r, "/", http.StatusSeeOther, false) + return util.Redirect(w, r, c.Ctx.GetRouteByName("app.index"), http.StatusSeeOther, false) } } } @@ -101,7 +102,7 @@ func (c SignInHandler) Get(w http.ResponseWriter, r *http.Request) error { formError = "" } - if err := user.SignIn(formError).Render(context.Background(), w); err != nil { + if err := user.SignIn(formError, c.Ctx).Render(context.Background(), w); err != nil { c.Logger.Error(c.Logger.Wrap(err, "Error rendering sign in form")) return err } @@ -122,7 +123,7 @@ func (c SignInHandler) Post(w http.ResponseWriter, r *http.Request) error { foundUser, err := c.Db.Repo().FindUserByUsername(fd.Username) if foundUser == nil || err != nil { - return util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, false) + return util.Redirect(w, r, c.Ctx.GetRouteByName("app.user.sign_in"), http.StatusPermanentRedirect, false) } authenticated, err := util.CheckPassword(fd.Password, foundUser.Password) @@ -182,9 +183,9 @@ func (c SignInHandler) Post(w http.ResponseWriter, r *http.Request) error { c.Logger.Info(foundUser.Username + " logged in") if foundUser.Admin == 1 { - return util.Redirect(w, r, "/admin/", http.StatusSeeOther, false) + return util.Redirect(w, r, c.Ctx.GetRouteByName("app.admin.index"), http.StatusSeeOther, false) } else { - return util.Redirect(w, r, "/", http.StatusSeeOther, false) + return util.Redirect(w, r, c.Ctx.GetRouteByName("app.index"), http.StatusSeeOther, false) } } diff --git a/lib/middleware/middleware.go b/lib/middleware/middleware.go index 7733cbd..ab47afd 100644 --- a/lib/middleware/middleware.go +++ b/lib/middleware/middleware.go @@ -2,37 +2,71 @@ package middleware import ( "fmt" - "git.markbailey.dev/cerbervs/ptpp/app/session" - "git.markbailey.dev/cerbervs/ptpp/util" "net/http" "os" + "git.markbailey.dev/cerbervs/ptpp/app/session" + "git.markbailey.dev/cerbervs/ptpp/util" + "git.markbailey.dev/cerbervs/ptpp/lib/logger" ) -type Func func(http.Handler) http.Handler +type IRouterCtx interface { + GetRouteByName(string) (string, error) +} + +type IMiddlewareCtx interface { + GetRouteByName(string) string +} + +type MiddlewareCtx struct { + RouterCtx IRouterCtx +} + +func (m MiddlewareCtx) GetRouteByName(name string) string { + path, err := m.RouterCtx.GetRouteByName(name) + if err != nil { + panic(err) + } + + return path +} + +type Func func(http.Handler, IMiddlewareCtx) http.Handler var sess session.IManager func Compose(xs []Func) Func { - return func(next http.Handler) http.Handler { + return func(next http.Handler, m IMiddlewareCtx) http.Handler { for i := len(xs) - 1; i >= 0; i-- { x := xs[i] - next = x(next) + next = x(next, m) } return next } } -func WithLogger(next http.Handler) http.Handler { +func DontPanic(next http.Handler, m IMiddlewareCtx) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + http.Error(w, fmt.Sprintf("%v", r), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} + +func WithLogger(next http.Handler, m IMiddlewareCtx) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerSess := sess.SessionStart(w, r) username, ok := handlerSess.Get("username").(string) if !ok { - username = "" + username = "" } + username = "(username " + username + ")" ipAddr := r.Header.Get("X-Real-IP") if ipAddr == "" { @@ -44,7 +78,7 @@ func WithLogger(next http.Handler) http.Handler { handlerLogger := logger.NewCompositeLogger() output := fmt.Sprintf( - "%s Request sent from %s to %s (username? %s)", + "%s Request sent from %s to %s %s", r.Method, ipAddr, r.URL.Path, @@ -56,7 +90,7 @@ func WithLogger(next http.Handler) http.Handler { }) } -func WithAuth(next http.Handler) http.Handler { +func WithAuth(next http.Handler, m IMiddlewareCtx) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ( claims *util.CustomClaims @@ -75,15 +109,15 @@ func WithAuth(next http.Handler) http.Handler { } if cookie, err = r.Cookie("token"); err != nil { - _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) + _ = util.Redirect(w, r, m.GetRouteByName("app.user.sign_in"), http.StatusPermanentRedirect, true) return } if token = cookie.Value; token == "" { - _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) + _ = util.Redirect(w, r, m.GetRouteByName("app.user.sign_in"), http.StatusPermanentRedirect, true) return } if claims, err = util.ParseToken(token, os.Getenv("TOKEN_SECRET")); err != nil { - _ = util.Redirect(w, r, "/sign-in", http.StatusPermanentRedirect, true) + _ = util.Redirect(w, r, m.GetRouteByName("app.user.sign_in"), http.StatusPermanentRedirect, true) return } @@ -94,7 +128,7 @@ func WithAuth(next http.Handler) http.Handler { }) } -func WithUsername(next http.Handler) http.Handler { +func WithUsername(next http.Handler, m IMiddlewareCtx) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ( claims *util.CustomClaims diff --git a/view/user/signin.templ b/view/user/signin.templ index 389388f..15f721b 100644 --- a/view/user/signin.templ +++ b/view/user/signin.templ @@ -1,8 +1,11 @@ package user -import "git.markbailey.dev/cerbervs/ptpp/view/layout" +import ( + "git.markbailey.dev/cerbervs/ptpp/view/layout" + "git.markbailey.dev/cerbervs/ptpp/app/controller" +) -templ SignIn(err string) { +templ SignIn(err string, context controller.IControllerCtx) { @layout.Layout() {
@@ -15,7 +18,7 @@ templ SignIn(err string) {
diff --git a/view/user/signin_templ.go b/view/user/signin_templ.go index 5aeaef9..26bfe80 100644 --- a/view/user/signin_templ.go +++ b/view/user/signin_templ.go @@ -8,9 +8,12 @@ package user import "github.com/a-h/templ" import templruntime "github.com/a-h/templ/runtime" -import "git.markbailey.dev/cerbervs/ptpp/view/layout" +import ( + "git.markbailey.dev/cerbervs/ptpp/app/controller" + "git.markbailey.dev/cerbervs/ptpp/view/layout" +) -func SignIn(err string) templ.Component { +func SignIn(err string, context controller.IControllerCtx) templ.Component { return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { @@ -43,7 +46,20 @@ func SignIn(err string) templ.Component { }() } ctx = templ.InitializeContext(ctx) - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("

Create an account

") + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("

Create an account

") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -52,12 +68,12 @@ func SignIn(err string) templ.Component { if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } - var templ_7745c5c3_Var3 string - templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(err) + var templ_7745c5c3_Var4 string + templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(err) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `view/user/signin.templ`, Line: 52, Col: 14} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `view/user/signin.templ`, Line: 55, Col: 14} } - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3)) + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4)) if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err }