add the visitor API endpoint
This commit is contained in:
@ -17,6 +17,18 @@ func (db *Type) Load() (err error) {
|
||||
return fmt.Errorf("cannot access the database: %s", err.Error())
|
||||
}
|
||||
|
||||
// see database/visitor.go
|
||||
_, err = db.sql.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS visitor_count(
|
||||
id TEXT NOT NULL UNIQUE,
|
||||
count INTEGER NOT NULL
|
||||
);
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the visitor_count table: %s", err.Error())
|
||||
}
|
||||
|
||||
// see database/service.go
|
||||
_, err = db.sql.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS services(
|
||||
|
48
api/database/visitor.go
Normal file
48
api/database/visitor.go
Normal file
@ -0,0 +1,48 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func (db *Type) VisitorGet() (uint64, error) {
|
||||
var (
|
||||
row *sql.Row
|
||||
count uint64
|
||||
err error
|
||||
)
|
||||
|
||||
if row = db.sql.QueryRow("SELECT count FROM visitor_count WHERE id = 0"); row == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if err = row.Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (db *Type) VisitorIncrement() (err error) {
|
||||
if _, err = db.sql.Exec("UPDATE visitor_count SET count = count + 1 WHERE id = 0"); err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: err is always nil even if there is no rows for some reason, check sql.Result instead
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
_, err = db.sql.Exec(
|
||||
`INSERT INTO visitor_count(
|
||||
id, count
|
||||
) values(?, ?)`,
|
||||
0, 0,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -89,6 +89,7 @@ func main() {
|
||||
|
||||
// v1 user routes
|
||||
v1.Get("/services", routes.GET_Services)
|
||||
v1.Get("/visitor", routes.GET_Visitor)
|
||||
v1.Get("/news/:lang", routes.GET_News)
|
||||
|
||||
// v1 admin routes
|
||||
|
50
api/routes/visitor.go
Normal file
50
api/routes/visitor.go
Normal file
@ -0,0 +1,50 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/ngn13/website/api/database"
|
||||
"github.com/ngn13/website/api/util"
|
||||
)
|
||||
|
||||
const LAST_ADDRS_MAX = 30
|
||||
|
||||
var last_addrs []string
|
||||
|
||||
func GET_Visitor(c *fiber.Ctx) error {
|
||||
var (
|
||||
err error
|
||||
count uint64
|
||||
)
|
||||
|
||||
db := c.Locals("database").(*database.Type)
|
||||
new_addr := util.GetSHA1(util.IP(c))
|
||||
|
||||
for _, addr := range last_addrs {
|
||||
if new_addr == addr {
|
||||
if count, err = db.VisitorGet(); err != nil {
|
||||
return util.ErrInternal(c, err)
|
||||
}
|
||||
|
||||
return util.JSON(c, 200, fiber.Map{
|
||||
"result": count,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err = db.VisitorIncrement(); err != nil {
|
||||
return util.ErrInternal(c, err)
|
||||
}
|
||||
|
||||
if count, err = db.VisitorGet(); err != nil {
|
||||
return util.ErrInternal(c, err)
|
||||
}
|
||||
|
||||
if len(last_addrs) > LAST_ADDRS_MAX {
|
||||
last_addrs = append(last_addrs[:0], last_addrs[1:]...)
|
||||
last_addrs = append(last_addrs, new_addr)
|
||||
}
|
||||
|
||||
return util.JSON(c, 200, fiber.Map{
|
||||
"result": count,
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user