From 9631958a82c70f30421fc4e292f700ec8881805e Mon Sep 17 00:00:00 2001 From: Chongyi Zheng Date: Mon, 18 Sep 2023 04:40:50 -0400 Subject: [PATCH] Refactor lfs requests (#26783) - Refactor lfs request code - The original code uses `performRequest` function to create the request, uses a callback to modify the request, and then send the request. - Now it's replaced with `createRequest` that only creates request and `performRequest` that only sends the request. - Reuse `createRequest` and `performRequest` in `http_client.go` and `transferadapter.go` --------- Co-authored-by: wxiaoguang --- modules/lfs/filesystem_client.go | 12 ++-- modules/lfs/http_client.go | 104 +++++++++++++++++++---------- modules/lfs/http_client_test.go | 36 ++++++----- modules/lfs/pointer.go | 4 +- modules/lfs/transferadapter.go | 108 +++++++++---------------------- modules/util/path.go | 1 + 6 files changed, 127 insertions(+), 138 deletions(-) diff --git a/modules/lfs/filesystem_client.go b/modules/lfs/filesystem_client.go index 835551e00c..3503a9effc 100644 --- a/modules/lfs/filesystem_client.go +++ b/modules/lfs/filesystem_client.go @@ -15,7 +15,7 @@ import ( // FilesystemClient is used to read LFS data from a filesystem path type FilesystemClient struct { - lfsdir string + lfsDir string } // BatchSize returns the preferred size of batchs to process @@ -25,16 +25,12 @@ func (c *FilesystemClient) BatchSize() int { func newFilesystemClient(endpoint *url.URL) *FilesystemClient { path, _ := util.FileURLToPath(endpoint) - - lfsdir := filepath.Join(path, "lfs", "objects") - - client := &FilesystemClient{lfsdir} - - return client + lfsDir := filepath.Join(path, "lfs", "objects") + return &FilesystemClient{lfsDir} } func (c *FilesystemClient) objectPath(oid string) string { - return filepath.Join(c.lfsdir, oid[0:2], oid[2:4], oid) + return filepath.Join(c.lfsDir, oid[0:2], oid[2:4], oid) } // Download reads the specific LFS object from the target path diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go index ec0d6269bd..de0b1e4fed 100644 --- a/modules/lfs/http_client.go +++ b/modules/lfs/http_client.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "io" "net/http" "net/url" "strings" @@ -17,7 +18,7 @@ import ( "code.gitea.io/gitea/modules/proxy" ) -const batchSize = 20 +const httpBatchSize = 20 // HTTPClient is used to communicate with the LFS server // https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md @@ -29,7 +30,7 @@ type HTTPClient struct { // BatchSize returns the preferred size of batchs to process func (c *HTTPClient) BatchSize() int { - return batchSize + return httpBatchSize } func newHTTPClient(endpoint *url.URL, httpTransport *http.Transport) *HTTPClient { @@ -43,28 +44,25 @@ func newHTTPClient(endpoint *url.URL, httpTransport *http.Transport) *HTTPClient Transport: httpTransport, } - client := &HTTPClient{ - client: hc, - endpoint: strings.TrimSuffix(endpoint.String(), "/"), - transfers: make(map[string]TransferAdapter), - } - basic := &BasicTransferAdapter{hc} - - client.transfers[basic.Name()] = basic + client := &HTTPClient{ + client: hc, + endpoint: strings.TrimSuffix(endpoint.String(), "/"), + transfers: map[string]TransferAdapter{ + basic.Name(): basic, + }, + } return client } func (c *HTTPClient) transferNames() []string { keys := make([]string, len(c.transfers)) - i := 0 for k := range c.transfers { keys[i] = k i++ } - return keys } @@ -74,7 +72,6 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin url := fmt.Sprintf("%s/objects/batch", c.endpoint) request := &BatchRequest{operation, c.transferNames(), nil, objects} - payload := new(bytes.Buffer) err := json.NewEncoder(payload).Encode(request) if err != nil { @@ -82,32 +79,17 @@ func (c *HTTPClient) batch(ctx context.Context, operation string, objects []Poin return nil, err } - log.Trace("Calling: %s", url) - - req, err := http.NewRequestWithContext(ctx, "POST", url, payload) + req, err := createRequest(ctx, http.MethodPost, url, map[string]string{"Content-Type": MediaType}, payload) if err != nil { - log.Error("Error creating request: %v", err) return nil, err } - req.Header.Set("Content-type", MediaType) - req.Header.Set("Accept", MediaType) - res, err := c.client.Do(req) + res, err := performRequest(ctx, c.client, req) if err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - log.Error("Error while processing request: %v", err) return nil, err } defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Unexpected server response: %s", res.Status) - } - var response BatchResponse err = json.NewDecoder(res.Body).Decode(&response) if err != nil { @@ -177,7 +159,7 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc link, ok := object.Actions["upload"] if !ok { log.Debug("%+v", object) - return errors.New("Missing action 'upload'") + return errors.New("missing action 'upload'") } content, err := uc(object.Pointer, nil) @@ -187,8 +169,6 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc err = transferAdapter.Upload(ctx, link, object.Pointer, content) - content.Close() - if err != nil { return err } @@ -203,7 +183,7 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc link, ok := object.Actions["download"] if !ok { log.Debug("%+v", object) - return errors.New("Missing action 'download'") + return errors.New("missing action 'download'") } content, err := transferAdapter.Download(ctx, link) @@ -219,3 +199,59 @@ func (c *HTTPClient) performOperation(ctx context.Context, objects []Pointer, dc return nil } + +// createRequest creates a new request, and sets the headers. +func createRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader) (*http.Request, error) { + log.Trace("createRequest: %s", url) + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + log.Error("Error creating request: %v", err) + return nil, err + } + + for key, value := range headers { + req.Header.Set(key, value) + } + req.Header.Set("Accept", MediaType) + + return req, nil +} + +// performRequest sends a request, optionally performs a callback on the request and returns the response. +// If the status code is 200, the response is returned, and it will contain a non-nil Body. +// Otherwise, it will return an error, and the Body will be nil or closed. +func performRequest(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + log.Trace("performRequest: %s", req.URL) + res, err := client.Do(req) + if err != nil { + select { + case <-ctx.Done(): + return res, ctx.Err() + default: + } + log.Error("Error while processing request: %v", err) + return res, err + } + + if res.StatusCode != http.StatusOK { + defer res.Body.Close() + return res, handleErrorResponse(res) + } + + return res, nil +} + +func handleErrorResponse(resp *http.Response) error { + var er ErrorResponse + err := json.NewDecoder(resp.Body).Decode(&er) + if err != nil { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + log.Error("Error decoding json: %v", err) + return err + } + + log.Trace("ErrorResponse: %v", er) + return errors.New(er.Message) +} diff --git a/modules/lfs/http_client_test.go b/modules/lfs/http_client_test.go index cb71b9008a..7459d9c0c9 100644 --- a/modules/lfs/http_client_test.go +++ b/modules/lfs/http_client_test.go @@ -177,7 +177,7 @@ func TestHTTPClientDownload(t *testing.T) { // case 0 { endpoint: "https://status-not-ok.io", - expectederror: "Unexpected server response: ", + expectederror: io.ErrUnexpectedEOF.Error(), }, // case 1 { @@ -207,7 +207,7 @@ func TestHTTPClientDownload(t *testing.T) { // case 6 { endpoint: "https://empty-actions-map.io", - expectederror: "Missing action 'download'", + expectederror: "missing action 'download'", }, // case 7 { @@ -217,27 +217,28 @@ func TestHTTPClientDownload(t *testing.T) { // case 8 { endpoint: "https://upload-actions-map.io", - expectederror: "Missing action 'download'", + expectederror: "missing action 'download'", }, // case 9 { endpoint: "https://verify-actions-map.io", - expectederror: "Missing action 'download'", + expectederror: "missing action 'download'", }, // case 10 { endpoint: "https://unknown-actions-map.io", - expectederror: "Missing action 'download'", + expectederror: "missing action 'download'", }, } for n, c := range cases { client := &HTTPClient{ - client: hc, - endpoint: c.endpoint, - transfers: make(map[string]TransferAdapter), + client: hc, + endpoint: c.endpoint, + transfers: map[string]TransferAdapter{ + "dummy": dummy, + }, } - client.transfers["dummy"] = dummy err := client.Download(context.Background(), []Pointer{p}, func(p Pointer, content io.ReadCloser, objectError error) error { if objectError != nil { @@ -284,7 +285,7 @@ func TestHTTPClientUpload(t *testing.T) { // case 0 { endpoint: "https://status-not-ok.io", - expectederror: "Unexpected server response: ", + expectederror: io.ErrUnexpectedEOF.Error(), }, // case 1 { @@ -319,7 +320,7 @@ func TestHTTPClientUpload(t *testing.T) { // case 7 { endpoint: "https://download-actions-map.io", - expectederror: "Missing action 'upload'", + expectederror: "missing action 'upload'", }, // case 8 { @@ -329,22 +330,23 @@ func TestHTTPClientUpload(t *testing.T) { // case 9 { endpoint: "https://verify-actions-map.io", - expectederror: "Missing action 'upload'", + expectederror: "missing action 'upload'", }, // case 10 { endpoint: "https://unknown-actions-map.io", - expectederror: "Missing action 'upload'", + expectederror: "missing action 'upload'", }, } for n, c := range cases { client := &HTTPClient{ - client: hc, - endpoint: c.endpoint, - transfers: make(map[string]TransferAdapter), + client: hc, + endpoint: c.endpoint, + transfers: map[string]TransferAdapter{ + "dummy": dummy, + }, } - client.transfers["dummy"] = dummy err := client.Upload(context.Background(), []Pointer{p}, func(p Pointer, objectError error) (io.ReadCloser, error) { return io.NopCloser(new(bytes.Buffer)), objectError diff --git a/modules/lfs/pointer.go b/modules/lfs/pointer.go index d7653e836c..3e5bb8f91d 100644 --- a/modules/lfs/pointer.go +++ b/modules/lfs/pointer.go @@ -29,10 +29,10 @@ const ( var ( // ErrMissingPrefix occurs if the content lacks the LFS prefix - ErrMissingPrefix = errors.New("Content lacks the LFS prefix") + ErrMissingPrefix = errors.New("content lacks the LFS prefix") // ErrInvalidStructure occurs if the content has an invalid structure - ErrInvalidStructure = errors.New("Content has an invalid structure") + ErrInvalidStructure = errors.New("content has an invalid structure") // ErrInvalidOIDFormat occurs if the oid has an invalid format ErrInvalidOIDFormat = errors.New("OID has an invalid format") diff --git a/modules/lfs/transferadapter.go b/modules/lfs/transferadapter.go index 649497aabb..d425b91946 100644 --- a/modules/lfs/transferadapter.go +++ b/modules/lfs/transferadapter.go @@ -6,8 +6,6 @@ package lfs import ( "bytes" "context" - "errors" - "fmt" "io" "net/http" @@ -15,7 +13,7 @@ import ( "code.gitea.io/gitea/modules/log" ) -// TransferAdapter represents an adapter for downloading/uploading LFS objects +// TransferAdapter represents an adapter for downloading/uploading LFS objects. type TransferAdapter interface { Name() string Download(ctx context.Context, l *Link) (io.ReadCloser, error) @@ -23,41 +21,48 @@ type TransferAdapter interface { Verify(ctx context.Context, l *Link, p Pointer) error } -// BasicTransferAdapter implements the "basic" adapter +// BasicTransferAdapter implements the "basic" adapter. type BasicTransferAdapter struct { client *http.Client } -// Name returns the name of the adapter +// Name returns the name of the adapter. func (a *BasicTransferAdapter) Name() string { return "basic" } -// Download reads the download location and downloads the data +// Download reads the download location and downloads the data. func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) { - resp, err := a.performRequest(ctx, "GET", l, nil, nil) + req, err := createRequest(ctx, http.MethodGet, l.Href, l.Header, nil) + if err != nil { + return nil, err + } + resp, err := performRequest(ctx, a.client, req) if err != nil { return nil, err } return resp.Body, nil } -// Upload sends the content to the LFS server +// Upload sends the content to the LFS server. func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error { - _, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) { - if len(req.Header.Get("Content-Type")) == 0 { - req.Header.Set("Content-Type", "application/octet-stream") - } - - if req.Header.Get("Transfer-Encoding") == "chunked" { - req.TransferEncoding = []string{"chunked"} - } - - req.ContentLength = p.Size - }) + req, err := createRequest(ctx, http.MethodPut, l.Href, l.Header, r) if err != nil { return err } + if req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/octet-stream") + } + if req.Header.Get("Transfer-Encoding") == "chunked" { + req.TransferEncoding = []string{"chunked"} + } + req.ContentLength = p.Size + + res, err := performRequest(ctx, a.client, req) + if err != nil { + return err + } + defer res.Body.Close() return nil } @@ -69,66 +74,15 @@ func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) e return err } - _, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) { - req.Header.Set("Content-Type", MediaType) - }) + req, err := createRequest(ctx, http.MethodPost, l.Href, l.Header, bytes.NewReader(b)) if err != nil { return err } + req.Header.Set("Content-Type", MediaType) + res, err := performRequest(ctx, a.client, req) + if err != nil { + return err + } + defer res.Body.Close() return nil } - -func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) { - log.Trace("Calling: %s %s", method, l.Href) - - req, err := http.NewRequestWithContext(ctx, method, l.Href, body) - if err != nil { - log.Error("Error creating request: %v", err) - return nil, err - } - for key, value := range l.Header { - req.Header.Set(key, value) - } - req.Header.Set("Accept", MediaType) - - if callback != nil { - callback(req) - } - - res, err := a.client.Do(req) - if err != nil { - select { - case <-ctx.Done(): - return res, ctx.Err() - default: - } - log.Error("Error while processing request: %v", err) - return res, err - } - - if res.StatusCode != http.StatusOK { - return res, handleErrorResponse(res) - } - - return res, nil -} - -func handleErrorResponse(resp *http.Response) error { - defer resp.Body.Close() - - er, err := decodeResponseError(resp.Body) - if err != nil { - return fmt.Errorf("Request failed with status %s", resp.Status) - } - log.Trace("ErrorRespone: %v", er) - return errors.New(er.Message) -} - -func decodeResponseError(r io.Reader) (ErrorResponse, error) { - var er ErrorResponse - err := json.NewDecoder(r).Decode(&er) - if err != nil { - log.Error("Error decoding json: %v", err) - } - return er, err -} diff --git a/modules/util/path.go b/modules/util/path.go index 58258560dd..e8537fb6b9 100644 --- a/modules/util/path.go +++ b/modules/util/path.go @@ -225,6 +225,7 @@ func isOSWindows() bool { var driveLetterRegexp = regexp.MustCompile("/[A-Za-z]:/") // FileURLToPath extracts the path information from a file://... url. +// It returns an error only if the URL is not a file URL. func FileURLToPath(u *url.URL) (string, error) { if u.Scheme != "file" { return "", errors.New("URL scheme is not 'file': " + u.String())