fix typo in upload size check

The scp upload size check had a typo preventing files from reporting
their size, causing an extra temp file to be created.
This commit is contained in:
James Bardin 2022-11-08 10:21:08 -05:00
parent be5984d664
commit 8ba8d5aec4
2 changed files with 28 additions and 4 deletions

View File

@ -418,7 +418,7 @@ func (c *Communicator) Upload(path string, input io.Reader) error {
switch src := input.(type) { switch src := input.(type) {
case *os.File: case *os.File:
fi, err := src.Stat() fi, err := src.Stat()
if err != nil { if err == nil {
size = fi.Size() size = fi.Size()
} }
case *bytes.Buffer: case *bytes.Buffer:
@ -641,7 +641,13 @@ func checkSCPStatus(r *bufio.Reader) error {
return nil return nil
} }
var testUploadSizeHook func(size int64)
func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error { func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader, size int64) error {
if testUploadSizeHook != nil {
testUploadSizeHook(size)
}
if size == 0 { if size == 0 {
// Create a temporary file where we can copy the contents of the src // Create a temporary file where we can copy the contents of the src
// so that we can determine the length, since SCP is length-prefixed. // so that we can determine the length, since SCP is length-prefixed.

View File

@ -577,10 +577,28 @@ func TestAccUploadFile(t *testing.T) {
} }
tmpDir := t.TempDir() tmpDir := t.TempDir()
source, err := os.CreateTemp(tmpDir, "tempfile.in")
if err != nil {
t.Fatal(err)
}
content := "this is the file content"
if _, err := source.WriteString(content); err != nil {
t.Fatal(err)
}
source.Seek(0, io.SeekStart)
content := []byte("this is the file content")
source := bytes.NewReader(content)
tmpFile := filepath.Join(tmpDir, "tempFile.out") tmpFile := filepath.Join(tmpDir, "tempFile.out")
testUploadSizeHook = func(size int64) {
if size != int64(len(content)) {
t.Errorf("expected %d bytes, got %d\n", len(content), size)
}
}
defer func() {
testUploadSizeHook = nil
}()
err = c.Upload(tmpFile, source) err = c.Upload(tmpFile, source)
if err != nil { if err != nil {
t.Fatalf("error uploading file: %s", err) t.Fatalf("error uploading file: %s", err)
@ -591,7 +609,7 @@ func TestAccUploadFile(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(data, content) { if string(data) != content {
t.Fatalf("bad: %s", data) t.Fatalf("bad: %s", data)
} }
} }