Now we're ready to start on the real implementation.
We'll begin by adding code to the routes (controllers) that checks whether the user has access to perform each action.
Before that, we need to change the webserver so that it passes the Cerberus tokens to the request context. (The access token will later be generated by the backend using the API key and secret and passed to the frontend, which then includes it with each request.)
In the 'internal/server/webserver.go' file, change the code so that it looks like this:

```go
// Read the Cerberus tokens sent by the frontend.
// (userId and accountId are obtained earlier in this handler, not shown.)
cerberusTokenPair := cerberus.TokenPair{
	AccessToken:  c.GetHeader("CerberusAccessToken"),
	RefreshToken: c.GetHeader("CerberusRefreshToken"),
}

// Make userId, accountId and the Cerberus token pair available to route handlers
c.Set("userId", userId)
c.Set("accountId", accountId)
c.Set("cerberusTokenPair", cerberusTokenPair)
c.Next()
```
Also change the CORS configuration so that the frontend is allowed to send the two new headers:

```go
corsConfig.AllowHeaders = []string{"Content-Type", "Authorization", "CerberusAccessToken", "CerberusRefreshToken"}
```
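The CORS change matters because CerberusAccessToken and CerberusRefreshToken are custom request headers. Before sending a cross-origin request that carries them, the browser makes a preflight OPTIONS request naming them in Access-Control-Request-Headers, and the real request is blocked unless the server allows every name listed. The sketch below is a simplified model of that check (the preflightAllowed helper is hypothetical, not the CORS middleware's actual code); note that header names compare case-insensitively.

```go
package main

import (
	"fmt"
	"strings"
)

// preflightAllowed reports whether every header named in a preflight's
// Access-Control-Request-Headers value appears in the allowed list.
// Header names are case-insensitive, so the comparison uses EqualFold.
func preflightAllowed(requested string, allowed []string) bool {
	for _, name := range strings.Split(requested, ",") {
		name = strings.TrimSpace(name)
		if name == "" {
			continue
		}
		found := false
		for _, a := range allowed {
			if strings.EqualFold(name, a) {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}

func main() {
	// Browsers typically send lowercased names in the preflight request.
	requested := "content-type, cerberusaccesstoken, cerberusrefreshtoken"

	before := []string{"Content-Type", "Authorization"}
	after := []string{"Content-Type", "Authorization", "CerberusAccessToken", "CerberusRefreshToken"}

	fmt.Println(preflightAllowed(requested, before)) // false
	fmt.Println(preflightAllowed(requested, after))  // true
}
```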
The following route changes are all in the 'internal/routes' folder.
Change the 'projects.go' file so that the struct and constructor look like this:

```go
type projectRoutes struct {
	service        services.ProjectService
	cerberusClient cerberus.CerberusClient
}

func NewProjectRoutes(service services.ProjectService, cerberusClient cerberus.CerberusClient) Routable {
	return &projectRoutes{service: service, cerberusClient: cerberusClient}
}
```
Then add the following access check to the 'Create' function, just before the request body is bound:

```go
hasAccess, err := r.cerberusClient.HasAccess(c, accountId, common.CreateProject_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

if err := c.Bind(&projectData); err != nil {
	// ...
```
Now change the 'Get' function to include:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, projectId, common.ReadProject_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

project, err := r.service.Get(
	// ...
```
And the 'Delete' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, projectId, common.DeleteProject_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

err = r.service.Delete(
	// ...
```
You should get the idea by now, and might even take a stab at completing the other routes, but we'll include the solution here anyway.
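Next up is the 'sprints.go' file. As with 'projects.go', add a cerberusClient field to the routes struct and a cerberus.CerberusClient parameter to NewSprintRoutes.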
Change the 'Create' function to include:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, projectId, common.CreateSprint_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

if err := c.Bind(&resourceTypeData); err != nil {
	// ...
```
And the 'Start' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, sprintId, common.StartSprint_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

rts, err := r.service.Start(
	// ...
```
And the 'End' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, sprintId, common.EndSprint_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

rts, err := r.service.End(
	// ...
```
And lastly, the 'Get' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, sprintId, common.ReadSprint_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

sprint, err := r.service.Get(
	// ...
```
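Next is the 'stories.go' file. Again, add the cerberusClient field to the routes struct and the cerberus.CerberusClient parameter to NewStoryRoutes.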
The 'Create' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, sprintId, common.CreateStory_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

if err := c.Bind(&data); err != nil {
	// ...
```
The 'Get' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, storyId, common.ReadStory_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

story, err := r.service.Get(
	// ...
```
The 'Estimate' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, storyId, common.EstimateStory_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

var data StoryData
// ...
```
The 'ChangeStatus' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, storyId, common.ChangeStoryStatus_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

var data StoryData
// ...
```
The 'Assign' function:
```go
hasAccess, err := r.cerberusClient.HasAccess(c, storyId, common.ChangeStoryAssignee_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

var data StoryData
// ...
```
And lastly, the 'users.go' file. Make the same struct change here and add the cerberus.CerberusClient parameter to NewUserRoutes.
In the 'Add' function:
```go
accountId, exists := c.Get("accountId")
if !exists {
	c.AbortWithStatusJSON(http.StatusBadRequest, jsonError(fmt.Errorf("no accountId")))
	// Return here: otherwise accountId.(string) below panics on the nil value.
	return
}

hasAccess, err := r.cerberusClient.HasAccess(c, accountId.(string), common.AddUser_A)
if err != nil || !hasAccess {
	c.AbortWithStatusJSON(http.StatusForbidden, jsonError(err))
	return
}

var userData UserData
// ...
```
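As you can see, all that's required to protect our app is to add a 'HasAccess' check to every controller function we want to protect. Note that create actions are checked against the parent resource (the account for new projects and users, the project for new sprints, the sprint for new stories), since the new resource doesn't exist yet; all other actions are checked against the resource itself.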
Because the route constructors now take the Cerberus client, we also need to update the 'cmd/api/api.go' file:
```go
userService := services.NewUserService(
	txProvider,
	userRepo,
	accountRepo,
	jwtSecret, saltRounds, cerberusClient)

publicRoutes := publicRoutes(userService)
privateRoutes := privateRoutes(
	cerberusClient,
	userService,
	services.NewProjectService(txProvider, projectRepo, cerberusClient),
	services.NewSprintService(txProvider, sprintRepo, cerberusClient),
	services.NewStoryService(txProvider, storyRepo, cerberusClient))
```
And also, the 'privateRoutes' function:
```go
func privateRoutes(
	cerberusClient cerberus.CerberusClient,
	userService services.UserService,
	projectService services.ProjectService,
	sprintService services.SprintService,
	storyService services.StoryService) []routes.Routable {
	return []routes.Routable{
		routes.NewUserRoutes(userService, cerberusClient),
		routes.NewProjectRoutes(projectService, cerberusClient),
		routes.NewSprintRoutes(sprintService, cerberusClient),
		routes.NewStoryRoutes(storyService, cerberusClient),
	}
}
```