diff --git a/api/api_test.go b/api/api_test.go index 4dda992..755d6d4 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -16,9 +16,10 @@ import ( type testInformation struct { Method string + Header map[string]string + Cookie map[string]string Body interface{} Query map[string]interface{} - Cookie map[string]string ResultBody interface{} ResultCookie []string @@ -57,8 +58,12 @@ func checkTestInformation(t *testing.T, url string, information []testInformatio for i, information := range information { var body io.Reader if information.Body != nil { - buf, _ := json.Marshal(information.Body) - body = bytes.NewReader(buf) + if b, ok := information.Body.([]byte); ok { + body = bytes.NewReader(b) + } else { + buf, _ := json.Marshal(information.Body) + body = bytes.NewReader(buf) + } } query := url2.Values{} @@ -77,6 +82,11 @@ func checkTestInformation(t *testing.T, url string, information []testInformatio }) } } + if information.Header != nil { + for name, value := range information.Header { + req.Header.Set(name, value) + } + } resp, _ := http.DefaultClient.Do(req) diff --git a/api/assets.go b/api/assets.go index b9c901e..1d98200 100644 --- a/api/assets.go +++ b/api/assets.go @@ -4,10 +4,9 @@ import ( "TheAdversary/config" "TheAdversary/database" "TheAdversary/schema" - "encoding/base64" "encoding/json" - "go.uber.org/zap" "gorm.io/gorm/clause" + "io" "net/http" "net/url" "path" @@ -37,6 +36,12 @@ func Assets(w http.ResponseWriter, r *http.Request) { } func assetsGet(w http.ResponseWriter, r *http.Request) { + _, ok := authorizedSession(r) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + query := r.URL.Query() request := database.GetDB().Table("assets") @@ -65,22 +70,32 @@ func assetsGet(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(&assets) } -type assetsPostPayload struct { - Name string `json:"name"` - Content string `json:"data"` -} - func assetsPost(w http.ResponseWriter, r *http.Request) { - var payload assetsPostPayload - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - InvalidJson.Send(w) + _, ok := authorizedSession(r) + if !ok { + w.WriteHeader(http.StatusUnauthorized) return } - rawData, err := base64.StdEncoding.DecodeString(payload.Content) + file, header, err := r.FormFile("file") if err != nil { - zap.S().Warnf("Cannot decode base64") - ApiError{Message: "invalid base64 content", Code: http.StatusUnprocessableEntity}.Send(w) + if err == http.ErrMissingFile { + ApiError{Message: "file is missing", Code: http.StatusUnprocessableEntity}.Send(w) + } else { + ApiError{Message: "could not parse file" + err.Error(), Code: http.StatusInternalServerError}.Send(w) + } + return + } + defer file.Close() + + var name string + if name = r.FormValue("name"); name == "" { + name = header.Filename + } + + rawData, err := io.ReadAll(file) + if err != nil { + ApiError{Message: "failed to read file", Code: http.StatusInternalServerError}.Send(w) return } @@ -89,7 +104,7 @@ func assetsPost(w http.ResponseWriter, r *http.Request) { Name string Data []byte Link string - }{Name: payload.Name, Data: rawData, Link: url.PathEscape(payload.Name)} + }{Name: name, Data: rawData, Link: url.PathEscape(name)} if database.GetDB().Table("assets").Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "name"}}, @@ -97,7 +112,7 @@ func assetsPost(w http.ResponseWriter, r *http.Request) { }).Create(&tmpDatabaseSchema).RowsAffected == 0 { w.WriteHeader(http.StatusConflict) } else { - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusCreated) json.NewEncoder(w).Encode(schema.Asset{ Id: tmpDatabaseSchema.Id, Name: tmpDatabaseSchema.Name, @@ -111,6 +126,12 @@ type assetsDeletePayload struct { } func assetsDelete(w http.ResponseWriter, r *http.Request) { + _, ok := authorizedSession(r) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + var payload assetsDeletePayload if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { InvalidJson.Send(w) diff --git a/api/assets_test.go b/api/assets_test.go index bc7ca4e..6dcb9ea 100644 --- a/api/assets_test.go +++ b/api/assets_test.go @@ -3,7 +3,9 @@ package api import ( "TheAdversary/database" "TheAdversary/schema" - "encoding/base64" + "bytes" + "fmt" + "mime/multipart" "net/http" "net/http/httptest" "testing" @@ -31,6 +33,13 @@ func TestAssetsGet(t *testing.T) { checkTestInformation(t, server.URL, []testInformation{ { Method: http.MethodGet, + Code: http.StatusUnauthorized, + }, + { + Method: http.MethodGet, + Cookie: map[string]string{ + "session_id": initSession(), + }, Query: map[string]interface{}{ "q": "linux", }, @@ -45,6 +54,9 @@ func TestAssetsGet(t *testing.T) { }, { Method: http.MethodGet, + Cookie: map[string]string{ + "session_id": initSession(), + }, Query: map[string]interface{}{ "limit": 1, }, @@ -59,6 +71,9 @@ func TestAssetsGet(t *testing.T) { }, { Method: http.MethodGet, + Cookie: map[string]string{ + "session_id": initSession(), + }, ResultBody: []schema.Asset{ { Id: 1, @@ -81,27 +96,47 @@ func TestAssetsPost(t *testing.T) { t.Fatal(err) } + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + mw.SetBoundary("test") + formFile, _ := mw.CreateFormFile("file", "srfwsr") + formFile.Write([]byte("just a test file")) + mw.WriteField("name", "test") + mw.Close() + + fmt.Println(buf.String()) + server := httptest.NewServer(http.HandlerFunc(assetsPost)) checkTestInformation(t, server.URL, []testInformation{ { Method: http.MethodPost, - Body: assetsPostPayload{ - Name: "test", - Content: base64.StdEncoding.EncodeToString([]byte("test asset")), + Code: http.StatusUnauthorized, + }, + { + Method: http.MethodPost, + Header: map[string]string{ + "Content-Type": "multipart/form-data; boundary=test", }, + Cookie: map[string]string{ + "session_id": initSession(), + }, + Body: buf.Bytes(), ResultBody: schema.Asset{ Id: 1, Name: "test", Link: "/assets/test", }, - Code: http.StatusOK, + Code: http.StatusCreated, }, { Method: http.MethodPost, - Body: assetsPostPayload{ - Name: "test", - Content: base64.StdEncoding.EncodeToString([]byte("test asset")), + Header: map[string]string{ + "Content-Type": "multipart/form-data; boundary=test", }, + Cookie: map[string]string{ + "session_id": initSession(), + }, + Body: buf.Bytes(), Code: http.StatusConflict, }, }) @@ -122,6 +157,13 @@ func TestAssetsDelete(t *testing.T) { checkTestInformation(t, server.URL, []testInformation{ { Method: http.MethodDelete, + Code: http.StatusUnauthorized, + }, + { + Method: http.MethodDelete, + Cookie: map[string]string{ + "session_id": initSession(), + }, Body: assetsDeletePayload{ Id: 1, }, @@ -129,6 +171,9 @@ func TestAssetsDelete(t *testing.T) { }, { Method: http.MethodDelete, + Cookie: map[string]string{ + "session_id": initSession(), + }, Body: assetsDeletePayload{ Id: 69, },