Cloudwatch: Use context in aws DescribeLogGroupsWithContext (#77176)

This commit is contained in:
Shabeeb Khalid 2023-11-01 21:06:06 +02:00 committed by GitHub
parent d5932760d9
commit a59588a62e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 66 additions and 72 deletions

View File

@ -236,7 +236,7 @@ func (e *cloudWatchExecutor) checkHealthLogs(ctx context.Context, pluginCtx back
return err return err
} }
logsClient := NewLogsAPI(session) logsClient := NewLogsAPI(session)
_, err = logsClient.DescribeLogGroups(&cloudwatchlogs.DescribeLogGroupsInput{Limit: aws.Int64(1)}) _, err = logsClient.DescribeLogGroupsWithContext(ctx, &cloudwatchlogs.DescribeLogGroupsInput{Limit: aws.Int64(1)})
return err return err
} }

View File

@ -220,7 +220,7 @@ func TestQuery_ResourceRequest_DescribeLogGroups_with_CrossAccountQuerying(t *te
t.Run("maps log group api response to resource response of log-groups", func(t *testing.T) { t.Run("maps log group api response to resource response of log-groups", func(t *testing.T) {
logsApi = mocks.LogsAPI{} logsApi = mocks.LogsAPI{}
logsApi.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{ logsApi.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{
LogGroups: []*cloudwatchlogs.LogGroup{ LogGroups: []*cloudwatchlogs.LogGroup{
{Arn: aws.String("arn:aws:logs:us-east-1:111:log-group:group_a"), LogGroupName: aws.String("group_a")}, {Arn: aws.String("arn:aws:logs:us-east-1:111:log-group:group_a"), LogGroupName: aws.String("group_a")},
}, },
@ -248,7 +248,7 @@ func TestQuery_ResourceRequest_DescribeLogGroups_with_CrossAccountQuerying(t *te
} }
]`, string(sender.Response.Body)) ]`, string(sender.Response.Body))
logsApi.AssertCalled(t, "DescribeLogGroups", logsApi.AssertCalled(t, "DescribeLogGroupsWithContext",
&cloudwatchlogs.DescribeLogGroupsInput{ &cloudwatchlogs.DescribeLogGroupsInput{
AccountIdentifiers: []*string{utils.Pointer("some-account-id")}, AccountIdentifiers: []*string{utils.Pointer("some-account-id")},
IncludeLinkedAccounts: utils.Pointer(true), IncludeLinkedAccounts: utils.Pointer(true),

View File

@ -307,7 +307,7 @@ func (e *cloudWatchExecutor) handleGetLogGroups(ctx context.Context, pluginCtx b
input.LogGroupNamePrefix = aws.String(logGroupNamePrefix) input.LogGroupNamePrefix = aws.String(logGroupNamePrefix)
} }
var response *cloudwatchlogs.DescribeLogGroupsOutput var response *cloudwatchlogs.DescribeLogGroupsOutput
response, err = logsClient.DescribeLogGroups(input) response, err = logsClient.DescribeLogGroupsWithContext(ctx, input)
if err != nil || response == nil { if err != nil || response == nil {
return nil, err return nil, err
} }

View File

@ -15,7 +15,7 @@ type LogsAPI struct {
mock.Mock mock.Mock
} }
func (l *LogsAPI) DescribeLogGroups(input *cloudwatchlogs.DescribeLogGroupsInput) (*cloudwatchlogs.DescribeLogGroupsOutput, error) { func (l *LogsAPI) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
args := l.Called(input) args := l.Called(input)
return args.Get(0).(*cloudwatchlogs.DescribeLogGroupsOutput), args.Error(1) return args.Get(0).(*cloudwatchlogs.DescribeLogGroupsOutput), args.Error(1)
@ -31,7 +31,7 @@ type LogsService struct {
mock.Mock mock.Mock
} }
func (l *LogsService) GetLogGroups(request resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error) { func (l *LogsService) GetLogGroupsWithContext(ctx context.Context, request resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error) {
args := l.Called(request) args := l.Called(request)
return args.Get(0).([]resources.ResourceResponse[resources.LogGroup]), args.Error(1) return args.Get(0).([]resources.ResourceResponse[resources.LogGroup]), args.Error(1)

View File

@ -37,7 +37,7 @@ type ListMetricsProvider interface {
} }
type LogGroupsProvider interface { type LogGroupsProvider interface {
GetLogGroups(request resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error) GetLogGroupsWithContext(ctx context.Context, request resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error)
GetLogGroupFieldsWithContext(ctx context.Context, request resources.LogGroupFieldsRequest, option ...request.Option) ([]resources.ResourceResponse[resources.LogGroupField], error) GetLogGroupFieldsWithContext(ctx context.Context, request resources.LogGroupFieldsRequest, option ...request.Option) ([]resources.ResourceResponse[resources.LogGroupField], error)
} }
@ -60,7 +60,7 @@ type CloudWatchMetricsAPIProvider interface {
} }
type CloudWatchLogsAPIProvider interface { type CloudWatchLogsAPIProvider interface {
DescribeLogGroups(*cloudwatchlogs.DescribeLogGroupsInput) (*cloudwatchlogs.DescribeLogGroupsOutput, error) DescribeLogGroupsWithContext(ctx context.Context, in *cloudwatchlogs.DescribeLogGroupsInput, opts ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error)
GetLogGroupFieldsWithContext(ctx context.Context, in *cloudwatchlogs.GetLogGroupFieldsInput, option ...request.Option) (*cloudwatchlogs.GetLogGroupFieldsOutput, error) GetLogGroupFieldsWithContext(ctx context.Context, in *cloudwatchlogs.GetLogGroupFieldsInput, option ...request.Option) (*cloudwatchlogs.GetLogGroupFieldsOutput, error)
} }

View File

@ -24,7 +24,7 @@ func LogGroupsHandler(ctx context.Context, pluginCtx backend.PluginContext, reqC
return nil, models.NewHttpError("newLogGroupsService error", http.StatusInternalServerError, err) return nil, models.NewHttpError("newLogGroupsService error", http.StatusInternalServerError, err)
} }
logGroups, err := service.GetLogGroups(request) logGroups, err := service.GetLogGroupsWithContext(ctx, request)
if err != nil { if err != nil {
return nil, models.NewHttpError("GetLogGroups error", http.StatusInternalServerError, err) return nil, models.NewHttpError("GetLogGroups error", http.StatusInternalServerError, err)
} }

View File

@ -31,7 +31,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("successfully returns 1 log group with account id", func(t *testing.T) { t.Run("successfully returns 1 log group with account id", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{{ mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{{
Value: resources.LogGroup{ Value: resources.LogGroup{
Arn: "some arn", Arn: "some arn",
Name: "some name", Name: "some name",
@ -53,7 +53,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("successfully returns multiple log groups with account id", func(t *testing.T) { t.Run("successfully returns multiple log groups with account id", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return( mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return(
[]resources.ResourceResponse[resources.LogGroup]{ []resources.ResourceResponse[resources.LogGroup]{
{ {
Value: resources.LogGroup{ Value: resources.LogGroup{
@ -99,7 +99,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("returns error when both logGroupPrefix and logGroup Pattern are provided", func(t *testing.T) { t.Run("returns error when both logGroupPrefix and logGroup Pattern are provided", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -115,7 +115,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("passes default log group limit and nil for logGroupNamePrefix, accountId, and logGroupPattern", func(t *testing.T) { t.Run("passes default log group limit and nil for logGroupNamePrefix, accountId, and logGroupPattern", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -125,7 +125,7 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 50, Limit: 50,
ResourceRequest: resources.ResourceRequest{}, ResourceRequest: resources.ResourceRequest{},
LogGroupNamePrefix: nil, LogGroupNamePrefix: nil,
@ -135,7 +135,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("passes default log group limit and nil for logGroupNamePrefix when both are absent", func(t *testing.T) { t.Run("passes default log group limit and nil for logGroupNamePrefix when both are absent", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -145,7 +145,7 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 50, Limit: 50,
LogGroupNamePrefix: nil, LogGroupNamePrefix: nil,
}) })
@ -153,7 +153,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("passes log group limit from query parameter", func(t *testing.T) { t.Run("passes log group limit from query parameter", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -163,14 +163,14 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 2, Limit: 2,
}) })
}) })
t.Run("passes logGroupPrefix from query parameter", func(t *testing.T) { t.Run("passes logGroupPrefix from query parameter", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -180,7 +180,7 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 50, Limit: 50,
LogGroupNamePrefix: utils.Pointer("some-prefix"), LogGroupNamePrefix: utils.Pointer("some-prefix"),
}) })
@ -188,7 +188,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("passes logGroupPattern from query parameter", func(t *testing.T) { t.Run("passes logGroupPattern from query parameter", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -198,7 +198,7 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 50, Limit: 50,
LogGroupNamePattern: utils.Pointer("some-pattern"), LogGroupNamePattern: utils.Pointer("some-pattern"),
}) })
@ -206,7 +206,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("passes logGroupPattern from query parameter", func(t *testing.T) { t.Run("passes logGroupPattern from query parameter", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil) mockLogsService.On("GetLogGroupsWithContext", mock.Anything).Return([]resources.ResourceResponse[resources.LogGroup]{}, nil)
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil
} }
@ -216,7 +216,7 @@ func TestLogGroupsRoute(t *testing.T) {
handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc)) handler := http.HandlerFunc(ResourceRequestMiddleware(LogGroupsHandler, logger, reqCtxFunc))
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
mockLogsService.AssertCalled(t, "GetLogGroups", resources.LogGroupsRequest{ mockLogsService.AssertCalled(t, "GetLogGroupsWithContext", resources.LogGroupsRequest{
Limit: 50, Limit: 50,
ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("some-account-id")}, ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("some-account-id")},
}) })
@ -224,7 +224,7 @@ func TestLogGroupsRoute(t *testing.T) {
t.Run("returns error if service returns error", func(t *testing.T) { t.Run("returns error if service returns error", func(t *testing.T) {
mockLogsService := mocks.LogsService{} mockLogsService := mocks.LogsService{}
mockLogsService.On("GetLogGroups", mock.Anything). mockLogsService.On("GetLogGroupsWithContext", mock.Anything).
Return([]resources.ResourceResponse[resources.LogGroup]{}, fmt.Errorf("some error")) Return([]resources.ResourceResponse[resources.LogGroup]{}, fmt.Errorf("some error"))
newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) { newLogGroupsService = func(_ context.Context, pluginCtx backend.PluginContext, reqCtxFactory models.RequestContextFactoryFunc, region string) (models.LogGroupsProvider, error) {
return &mockLogsService, nil return &mockLogsService, nil

View File

@ -20,7 +20,7 @@ func NewLogGroupsService(logsClient models.CloudWatchLogsAPIProvider, isCrossAcc
return &LogGroupsService{logGroupsAPI: logsClient, isCrossAccountEnabled: isCrossAccountEnabled} return &LogGroupsService{logGroupsAPI: logsClient, isCrossAccountEnabled: isCrossAccountEnabled}
} }
func (s *LogGroupsService) GetLogGroups(req resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error) { func (s *LogGroupsService) GetLogGroupsWithContext(ctx context.Context, req resources.LogGroupsRequest) ([]resources.ResourceResponse[resources.LogGroup], error) {
input := &cloudwatchlogs.DescribeLogGroupsInput{ input := &cloudwatchlogs.DescribeLogGroupsInput{
Limit: aws.Int64(req.Limit), Limit: aws.Int64(req.Limit),
LogGroupNamePrefix: req.LogGroupNamePrefix, LogGroupNamePrefix: req.LogGroupNamePrefix,
@ -39,7 +39,7 @@ func (s *LogGroupsService) GetLogGroups(req resources.LogGroupsRequest) ([]resou
result := []resources.ResourceResponse[resources.LogGroup]{} result := []resources.ResourceResponse[resources.LogGroup]{}
for { for {
response, err := s.logGroupsAPI.DescribeLogGroups(input) response, err := s.logGroupsAPI.DescribeLogGroupsWithContext(ctx, input)
if err != nil || response == nil { if err != nil || response == nil {
return nil, err return nil, err
} }

View File

@ -17,7 +17,7 @@ import (
func TestGetLogGroups(t *testing.T) { func TestGetLogGroups(t *testing.T) {
t.Run("Should map log groups response", func(t *testing.T) { t.Run("Should map log groups response", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return( mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(
&cloudwatchlogs.DescribeLogGroupsOutput{ &cloudwatchlogs.DescribeLogGroupsOutput{
LogGroups: []*cloudwatchlogs.LogGroup{ LogGroups: []*cloudwatchlogs.LogGroup{
{Arn: utils.Pointer("arn:aws:logs:us-east-1:111:log-group:group_a"), LogGroupName: utils.Pointer("group_a")}, {Arn: utils.Pointer("arn:aws:logs:us-east-1:111:log-group:group_a"), LogGroupName: utils.Pointer("group_a")},
@ -27,7 +27,7 @@ func TestGetLogGroups(t *testing.T) {
}, nil) }, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
resp, err := service.GetLogGroups(resources.LogGroupsRequest{}) resp, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{ assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{
@ -48,10 +48,10 @@ func TestGetLogGroups(t *testing.T) {
t.Run("Should return an empty error if api doesn't return any data", func(t *testing.T) { t.Run("Should return an empty error if api doesn't return any data", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
resp, err := service.GetLogGroups(resources.LogGroupsRequest{}) resp, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{}, resp) assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{}, resp)
@ -60,16 +60,16 @@ func TestGetLogGroups(t *testing.T) {
t.Run("Should only use LogGroupNamePrefix even if LogGroupNamePattern passed in resource call", func(t *testing.T) { t.Run("Should only use LogGroupNamePrefix even if LogGroupNamePattern passed in resource call", func(t *testing.T) {
// TODO: use LogGroupNamePattern when we have accounted for its behavior, still a little unexpected at the moment // TODO: use LogGroupNamePattern when we have accounted for its behavior, still a little unexpected at the moment
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
Limit: 0, Limit: 0,
LogGroupNamePrefix: utils.Pointer("test"), LogGroupNamePrefix: utils.Pointer("test"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
LogGroupNamePrefix: utils.Pointer("test"), LogGroupNamePrefix: utils.Pointer("test"),
}) })
@ -77,24 +77,24 @@ func TestGetLogGroups(t *testing.T) {
t.Run("Should call api without LogGroupNamePrefix nor LogGroupNamePattern if not passed in resource call", func(t *testing.T) { t.Run("Should call api without LogGroupNamePrefix nor LogGroupNamePattern if not passed in resource call", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
_, err := service.GetLogGroups(resources.LogGroupsRequest{}) _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{})
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
}) })
}) })
t.Run("Should return an error when API returns error", func(t *testing.T) { t.Run("Should return an error when API returns error", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{},
fmt.Errorf("some error")) fmt.Errorf("some error"))
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
_, err := service.GetLogGroups(resources.LogGroupsRequest{}) _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{})
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "some error", err.Error()) assert.Equal(t, "some error", err.Error())
@ -108,7 +108,7 @@ func TestGetLogGroups(t *testing.T) {
ListAllLogGroups: false, ListAllLogGroups: false,
} }
mockLogsAPI.On("DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.On("DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: aws.Int64(req.Limit), Limit: aws.Int64(req.Limit),
LogGroupNamePrefix: req.LogGroupNamePrefix, LogGroupNamePrefix: req.LogGroupNamePrefix,
}).Return(&cloudwatchlogs.DescribeLogGroupsOutput{ }).Return(&cloudwatchlogs.DescribeLogGroupsOutput{
@ -119,10 +119,10 @@ func TestGetLogGroups(t *testing.T) {
}, nil) }, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
resp, err := service.GetLogGroups(req) resp, err := service.GetLogGroupsWithContext(context.Background(), req)
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertNumberOfCalls(t, "DescribeLogGroups", 1) mockLogsAPI.AssertNumberOfCalls(t, "DescribeLogGroupsWithContext", 1)
assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{ assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{
{ {
AccountId: utils.Pointer("111"), AccountId: utils.Pointer("111"),
@ -140,7 +140,7 @@ func TestGetLogGroups(t *testing.T) {
} }
// first call // first call
mockLogsAPI.On("DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.On("DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: aws.Int64(req.Limit), Limit: aws.Int64(req.Limit),
LogGroupNamePrefix: req.LogGroupNamePrefix, LogGroupNamePrefix: req.LogGroupNamePrefix,
}).Return(&cloudwatchlogs.DescribeLogGroupsOutput{ }).Return(&cloudwatchlogs.DescribeLogGroupsOutput{
@ -151,7 +151,7 @@ func TestGetLogGroups(t *testing.T) {
}, nil) }, nil)
// second call // second call
mockLogsAPI.On("DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.On("DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: aws.Int64(req.Limit), Limit: aws.Int64(req.Limit),
LogGroupNamePrefix: req.LogGroupNamePrefix, LogGroupNamePrefix: req.LogGroupNamePrefix,
NextToken: utils.Pointer("token"), NextToken: utils.Pointer("token"),
@ -161,9 +161,9 @@ func TestGetLogGroups(t *testing.T) {
}, },
}, nil) }, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
resp, err := service.GetLogGroups(req) resp, err := service.GetLogGroupsWithContext(context.Background(), req)
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertNumberOfCalls(t, "DescribeLogGroups", 2) mockLogsAPI.AssertNumberOfCalls(t, "DescribeLogGroupsWithContext", 2)
assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{ assert.Equal(t, []resources.ResourceResponse[resources.LogGroup]{
{ {
AccountId: utils.Pointer("111"), AccountId: utils.Pointer("111"),
@ -180,16 +180,16 @@ func TestGetLogGroups(t *testing.T) {
func TestGetLogGroupsCrossAccountQuerying(t *testing.T) { func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should not includeLinkedAccounts or accountId if isCrossAccountEnabled is set to false", func(t *testing.T) { t.Run("Should not includeLinkedAccounts or accountId if isCrossAccountEnabled is set to false", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, false) service := NewLogGroupsService(mockLogsAPI, false)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")}, ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")},
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
}) })
@ -197,17 +197,17 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should replace LogGroupNamePrefix if LogGroupNamePattern passed in resource call", func(t *testing.T) { t.Run("Should replace LogGroupNamePrefix if LogGroupNamePattern passed in resource call", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, true) service := NewLogGroupsService(mockLogsAPI, true)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")}, ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")},
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
LogGroupNamePattern: utils.Pointer("pattern"), LogGroupNamePattern: utils.Pointer("pattern"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
AccountIdentifiers: []*string{utils.Pointer("accountId")}, AccountIdentifiers: []*string{utils.Pointer("accountId")},
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
LogGroupNamePrefix: utils.Pointer("pattern"), LogGroupNamePrefix: utils.Pointer("pattern"),
@ -217,15 +217,15 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should includeLinkedAccounts,and accountId if isCrossAccountEnabled is set to true", func(t *testing.T) { t.Run("Should includeLinkedAccounts,and accountId if isCrossAccountEnabled is set to true", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, true) service := NewLogGroupsService(mockLogsAPI, true)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")}, ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")},
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
IncludeLinkedAccounts: utils.Pointer(true), IncludeLinkedAccounts: utils.Pointer(true),
AccountIdentifiers: []*string{utils.Pointer("accountId")}, AccountIdentifiers: []*string{utils.Pointer("accountId")},
@ -234,15 +234,15 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should should not override prefix is there is no logGroupNamePattern", func(t *testing.T) { t.Run("Should should not override prefix is there is no logGroupNamePattern", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, true) service := NewLogGroupsService(mockLogsAPI, true)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")}, ResourceRequest: resources.ResourceRequest{AccountId: utils.Pointer("accountId")},
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
AccountIdentifiers: []*string{utils.Pointer("accountId")}, AccountIdentifiers: []*string{utils.Pointer("accountId")},
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
@ -252,15 +252,15 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should not includeLinkedAccounts, or accountId if accountId is nil", func(t *testing.T) { t.Run("Should not includeLinkedAccounts, or accountId if accountId is nil", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, true) service := NewLogGroupsService(mockLogsAPI, true)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),
LogGroupNamePrefix: utils.Pointer("prefix"), LogGroupNamePrefix: utils.Pointer("prefix"),
}) })
@ -268,10 +268,10 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
t.Run("Should should not override prefix is there is no logGroupNamePattern", func(t *testing.T) { t.Run("Should should not override prefix is there is no logGroupNamePattern", func(t *testing.T) {
mockLogsAPI := &mocks.LogsAPI{} mockLogsAPI := &mocks.LogsAPI{}
mockLogsAPI.On("DescribeLogGroups", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil) mockLogsAPI.On("DescribeLogGroupsWithContext", mock.Anything).Return(&cloudwatchlogs.DescribeLogGroupsOutput{}, nil)
service := NewLogGroupsService(mockLogsAPI, true) service := NewLogGroupsService(mockLogsAPI, true)
_, err := service.GetLogGroups(resources.LogGroupsRequest{ _, err := service.GetLogGroupsWithContext(context.Background(), resources.LogGroupsRequest{
ResourceRequest: resources.ResourceRequest{ ResourceRequest: resources.ResourceRequest{
AccountId: utils.Pointer("accountId"), AccountId: utils.Pointer("accountId"),
}, },
@ -279,7 +279,7 @@ func TestGetLogGroupsCrossAccountQuerying(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
mockLogsAPI.AssertCalled(t, "DescribeLogGroups", &cloudwatchlogs.DescribeLogGroupsInput{ mockLogsAPI.AssertCalled(t, "DescribeLogGroupsWithContext", &cloudwatchlogs.DescribeLogGroupsInput{
AccountIdentifiers: []*string{utils.Pointer("accountId")}, AccountIdentifiers: []*string{utils.Pointer("accountId")},
IncludeLinkedAccounts: utils.Pointer(true), IncludeLinkedAccounts: utils.Pointer(true),
Limit: utils.Pointer(int64(0)), Limit: utils.Pointer(int64(0)),

View File

@ -71,14 +71,8 @@ func (m *mockLogsSyncClient) StartQueryWithContext(ctx context.Context, input *c
return args.Get(0).(*cloudwatchlogs.StartQueryOutput), args.Error(1) return args.Get(0).(*cloudwatchlogs.StartQueryOutput), args.Error(1)
} }
func (m *fakeCWLogsClient) DescribeLogGroups(input *cloudwatchlogs.DescribeLogGroupsInput) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
m.calls.describeLogGroups = append(m.calls.describeLogGroups, input)
output := &m.logGroups[m.logGroupsIndex]
m.logGroupsIndex++
return output, nil
}
func (m *fakeCWLogsClient) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) { func (m *fakeCWLogsClient) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
m.calls.describeLogGroups = append(m.calls.describeLogGroups, input)
output := &m.logGroups[m.logGroupsIndex] output := &m.logGroups[m.logGroupsIndex]
m.logGroupsIndex++ m.logGroupsIndex++
return output, nil return output, nil
@ -207,7 +201,7 @@ func (c fakeCheckHealthClient) ListMetricsPagesWithContext(ctx aws.Context, inpu
return nil return nil
} }
func (c fakeCheckHealthClient) DescribeLogGroups(input *cloudwatchlogs.DescribeLogGroupsInput) (*cloudwatchlogs.DescribeLogGroupsOutput, error) { func (c fakeCheckHealthClient) DescribeLogGroupsWithContext(ctx context.Context, input *cloudwatchlogs.DescribeLogGroupsInput, option ...request.Option) (*cloudwatchlogs.DescribeLogGroupsOutput, error) {
if c.describeLogGroups != nil { if c.describeLogGroups != nil {
return c.describeLogGroups(input) return c.describeLogGroups(input)
} }