DEV: Drop OpenStruct for the context object in services

While using `OpenStruct` is nice, it’s generally not a very good idea as
it usually leads to performance problems.

The `OpenStruct` source code even says basically to avoid it.

Since the context object is crucial in our services, this patch replaces
`OpenStruct` with a custom implementation instead.
This commit is contained in:
Loïc Guitaut 2024-10-03 18:05:45 +02:00 committed by Loïc Guitaut
parent 974a3bfc41
commit 229773e7a8
20 changed files with 145 additions and 143 deletions

View File

@ -9,9 +9,10 @@ class UpdateSiteSetting
attribute :new_value attribute :new_value
attribute :allow_changing_hidden, :boolean, default: false attribute :allow_changing_hidden, :boolean, default: false
before_validation { self.setting_name = setting_name&.to_sym }
validates :setting_name, presence: true validates :setting_name, presence: true
end end
step :convert_name_to_sym
policy :setting_is_visible policy :setting_is_visible
policy :setting_is_configurable policy :setting_is_configurable
step :cleanup_value step :cleanup_value
@ -19,28 +20,25 @@ class UpdateSiteSetting
private private
def convert_name_to_sym(setting_name:)
context.setting_name = setting_name.to_sym
end
def current_user_is_admin(guardian:) def current_user_is_admin(guardian:)
guardian.is_admin? guardian.is_admin?
end end
def setting_is_visible(setting_name:) def setting_is_visible(contract:)
context.allow_changing_hidden || !SiteSetting.hidden_settings.include?(setting_name) contract.allow_changing_hidden || !SiteSetting.hidden_settings.include?(contract.setting_name)
end end
def setting_is_configurable(setting_name:) def setting_is_configurable(contract:)
return true if !SiteSetting.plugins[setting_name] return true if !SiteSetting.plugins[contract.setting_name]
Discourse.plugins_by_name[SiteSetting.plugins[setting_name]].configurable? Discourse.plugins_by_name[SiteSetting.plugins[contract.setting_name]].configurable?
end end
def cleanup_value(setting_name:, new_value:) def cleanup_value(contract:)
new_value = contract.new_value
new_value = new_value.strip if new_value.is_a?(String) new_value = new_value.strip if new_value.is_a?(String)
case SiteSetting.type_supervisor.get_type(setting_name) case SiteSetting.type_supervisor.get_type(contract.setting_name)
when :integer when :integer
new_value = new_value.tr("^-0-9", "").to_i if new_value.is_a?(String) new_value = new_value.tr("^-0-9", "").to_i if new_value.is_a?(String)
when :file_size_restriction when :file_size_restriction
@ -50,10 +48,10 @@ class UpdateSiteSetting
when :upload when :upload
new_value = Upload.get_from_url(new_value) || "" new_value = Upload.get_from_url(new_value) || ""
end end
context.new_value = new_value context[:new_value] = new_value
end end
def save(setting_name:, new_value:, guardian:) def save(contract:, new_value:, guardian:)
SiteSetting.set_and_log(setting_name, new_value, guardian.user) SiteSetting.set_and_log(contract.setting_name, new_value, guardian.user)
end end
end end

View File

@ -17,8 +17,24 @@ module Service
end end
# Simple structure to hold the context of the service during its whole lifecycle. # Simple structure to hold the context of the service during its whole lifecycle.
class Context < OpenStruct class Context
include ActiveModel::Serialization delegate :slice, to: :store
def initialize(context = {})
@store = context.symbolize_keys
end
def [](key)
store[key.to_sym]
end
def []=(key, value)
store[key.to_sym] = value
end
def to_h
store.dup
end
# @return [Boolean] returns +true+ if the context is set as successful (default) # @return [Boolean] returns +true+ if the context is set as successful (default)
def success? def success?
@ -48,27 +64,27 @@ module Service
# context.fail("failure": "something went wrong") # context.fail("failure": "something went wrong")
# @return [Context] # @return [Context]
def fail(context = {}) def fail(context = {})
merge(context) store.merge!(context.symbolize_keys)
@failure = true @failure = true
self self
end end
# Merges the given context into the current one.
# @!visibility private
def merge(other_context = {})
other_context.each { |key, value| self[key.to_sym] = value }
self
end
def inspect_steps def inspect_steps
StepsInspector.new(self) Service::StepsInspector.new(self)
end end
private private
attr_reader :store
def self.build(context = {}) def self.build(context = {})
self === context ? context : new(context) self === context ? context : new(context)
end end
def method_missing(method_name, *args, &block)
return super if args.present?
store[method_name]
end
end end
# Internal module to define available steps as DSL # Internal module to define available steps as DSL
@ -117,7 +133,7 @@ module Service
if method.parameters.any? { _1[0] != :keyreq } if method.parameters.any? { _1[0] != :keyreq }
raise "In #{type} '#{name}': default values in step implementations are not allowed. Maybe they could be defined in a contract?" raise "In #{type} '#{name}': default values in step implementations are not allowed. Maybe they could be defined in a contract?"
end end
args = context.to_h.slice(*method.parameters.select { _1[0] == :keyreq }.map(&:last)) args = context.slice(*method.parameters.select { _1[0] == :keyreq }.map(&:last))
context[result_key] = Context.build(object: object) context[result_key] = Context.build(object: object)
instance.instance_exec(**args, &method) instance.instance_exec(**args, &method)
end end
@ -180,7 +196,7 @@ module Service
attributes = class_name.attribute_names.map(&:to_sym) attributes = class_name.attribute_names.map(&:to_sym)
default_values = {} default_values = {}
default_values = context[default_values_from].slice(*attributes) if default_values_from default_values = context[default_values_from].slice(*attributes) if default_values_from
contract = class_name.new(default_values.merge(context.to_h.slice(*attributes))) contract = class_name.new(default_values.merge(context.slice(*attributes)))
context[contract_name] = contract context[contract_name] = contract
context[result_key] = Context.build context[result_key] = Context.build
if contract.invalid? if contract.invalid?
@ -347,7 +363,6 @@ module Service
# @!visibility private # @!visibility private
def initialize(initial_context = {}) def initialize(initial_context = {})
@initial_context = initial_context.with_indifferent_access
@context = Context.build(initial_context.merge(__steps__: self.class.steps)) @context = Context.build(initial_context.merge(__steps__: self.class.steps))
end end

View File

@ -1,10 +1,10 @@
# frozen_string_literal: true # frozen_string_literal: true
# = StepsInspector # = Service::StepsInspector
# #
# This class takes a {Service::Base::Context} object and inspects it. # This class takes a {Service::Base::Context} object and inspects it.
# It will output a list of steps and what is their known state. # It will output a list of steps and what is their known state.
class StepsInspector class Service::StepsInspector
# @!visibility private # @!visibility private
class Step class Step
attr_reader :step, :result, :nesting_level attr_reader :step, :result, :nesting_level

View File

@ -82,19 +82,18 @@ module Chat
end end
if memberships.blank? if memberships.blank?
context.added_user_ids = [] context[:added_user_ids] = []
return return
end end
context.added_user_ids = context[:added_user_ids] = ::Chat::UserChatChannelMembership
::Chat::UserChatChannelMembership .upsert_all(
.upsert_all( memberships,
memberships, unique_by: %i[user_id chat_channel_id],
unique_by: %i[user_id chat_channel_id], returning: Arel.sql("user_id, (xmax = '0') as inserted"),
returning: Arel.sql("user_id, (xmax = '0') as inserted"), )
) .select { |row| row["inserted"] }
.select { |row| row["inserted"] } .map { |row| row["user_id"] }
.map { |row| row["user_id"] }
::Chat::DirectMessageUser.upsert_all( ::Chat::DirectMessageUser.upsert_all(
context.added_user_ids.map do |id| context.added_user_ids.map do |id|

View File

@ -47,8 +47,10 @@ module Chat
end end
def create_memberships(channel:, contract:) def create_memberships(channel:, contract:)
context.added_user_ids = context[:added_user_ids] = ::Chat::Action::CreateMembershipsForAutoJoin.call(
::Chat::Action::CreateMembershipsForAutoJoin.call(channel: channel, contract: contract) channel: channel,
contract: contract,
)
end end
def recalculate_user_count(channel:, added_user_ids:) def recalculate_user_count(channel:, added_user_ids:)

View File

@ -100,11 +100,8 @@ module Chat
return if memberships_to_remove.empty? return if memberships_to_remove.empty?
context.merge( context[:users_removed_map] = Chat::Action::RemoveMemberships.call(
users_removed_map: memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
),
) )
end end

View File

@ -81,11 +81,8 @@ module Chat
return if memberships_to_remove.empty? return if memberships_to_remove.empty?
context.merge( context[:users_removed_map] = Chat::Action::RemoveMemberships.call(
users_removed_map: memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
Chat::Action::RemoveMemberships.call(
memberships: Chat::UserChatChannelMembership.where(id: memberships_to_remove),
),
) )
end end

View File

@ -59,15 +59,14 @@ module Chat
def find_or_create_thread(channel:, original_message:, contract:) def find_or_create_thread(channel:, original_message:, contract:)
if original_message.thread_id.present? if original_message.thread_id.present?
return context.thread = ::Chat::Thread.find_by(id: original_message.thread_id) return context[:thread] = ::Chat::Thread.find_by(id: original_message.thread_id)
end end
context.thread = context[:thread] = channel.threads.create(
channel.threads.create( title: contract.title,
title: contract.title, original_message: original_message,
original_message: original_message, original_message_user: original_message.user,
original_message_user: original_message.user, )
)
fail!(context.thread.errors.full_messages.join(", ")) if context.thread.invalid? fail!(context.thread.errors.full_messages.join(", ")) if context.thread.invalid?
end end
@ -76,7 +75,7 @@ module Chat
end end
def fetch_membership(guardian:) def fetch_membership(guardian:)
context.membership = context.thread.membership_for(guardian.user) context[:membership] = context.thread.membership_for(guardian.user)
end end
def publish_new_thread(channel:, original_message:) def publish_new_thread(channel:, original_message:)

View File

@ -65,18 +65,18 @@ module Chat
end end
def enabled_threads?(channel:) def enabled_threads?(channel:)
context.enabled_threads = channel.threading_enabled context[:enabled_threads] = channel.threading_enabled
end end
def can_view_channel(guardian:, channel:) def can_view_channel(guardian:, channel:)
guardian.can_preview_chat_channel?(channel) guardian.can_preview_chat_channel?(channel)
end end
def determine_target_message_id(contract:) def determine_target_message_id(contract:, membership:)
if contract.fetch_from_last_read if contract.fetch_from_last_read
context.target_message_id = context.membership&.last_read_message_id context[:target_message_id] = membership&.last_read_message_id
else else
context.target_message_id = contract.target_message_id context[:target_message_id] = contract.target_message_id
end end
end end
@ -92,7 +92,7 @@ module Chat
return true return true
end end
context.target_message_id = nil context[:target_message_id] = nil
true true
end end
@ -108,9 +108,9 @@ module Chat
target_date: contract.target_date, target_date: contract.target_date,
) )
context.can_load_more_past = messages_data[:can_load_more_past] context[:can_load_more_past] = messages_data[:can_load_more_past]
context.can_load_more_future = messages_data[:can_load_more_future] context[:can_load_more_future] = messages_data[:can_load_more_future]
context.target_message_id = messages_data[:target_message_id] context[:target_message_id] = messages_data[:target_message_id]
messages_data[:target_message] = ( messages_data[:target_message] = (
if messages_data[:target_message]&.thread_reply? && if messages_data[:target_message]&.thread_reply? &&
@ -121,7 +121,7 @@ module Chat
end end
) )
context.messages = [ context[:messages] = [
messages_data[:messages], messages_data[:messages],
messages_data[:past_messages]&.reverse, messages_data[:past_messages]&.reverse,
messages_data[:target_message], messages_data[:target_message],
@ -130,37 +130,36 @@ module Chat
end end
def fetch_tracking(guardian:) def fetch_tracking(guardian:)
context.tracking = {} context[:tracking] = {}
return if !context.thread_ids.present? return if !context.thread_ids.present?
context.tracking = context[:tracking] = ::Chat::TrackingStateReportQuery.call(
::Chat::TrackingStateReportQuery.call( guardian: guardian,
guardian: guardian, thread_ids: context.thread_ids,
thread_ids: context.thread_ids, include_threads: true,
include_threads: true, )
)
end end
def fetch_thread_ids(messages:) def fetch_thread_ids(messages:)
context.thread_ids = messages.map(&:thread_id).compact.uniq context[:thread_ids] = messages.map(&:thread_id).compact.uniq
end end
def fetch_thread_participants(messages:) def fetch_thread_participants(messages:)
return if context.thread_ids.empty? return if context.thread_ids.empty?
context.thread_participants = context[:thread_participants] = ::Chat::ThreadParticipantQuery.call(
::Chat::ThreadParticipantQuery.call(thread_ids: context.thread_ids) thread_ids: context.thread_ids,
)
end end
def fetch_thread_memberships(guardian:) def fetch_thread_memberships(guardian:)
return if context.thread_ids.empty? return if context.thread_ids.empty?
context.thread_memberships = context[:thread_memberships] = ::Chat::UserChatThreadMembership.where(
::Chat::UserChatThreadMembership.where( thread_id: context.thread_ids,
thread_id: context.thread_ids, user_id: guardian.user.id,
user_id: guardian.user.id, )
)
end end
def update_membership_last_viewed_at(guardian:) def update_membership_last_viewed_at(guardian:)

View File

@ -63,13 +63,13 @@ module Chat
def determine_target_message_id(contract:, membership:, guardian:, thread:) def determine_target_message_id(contract:, membership:, guardian:, thread:)
if contract.fetch_from_last_message if contract.fetch_from_last_message
context.target_message_id = thread.last_message_id context[:target_message_id] = thread.last_message_id
elsif contract.fetch_from_first_message elsif contract.fetch_from_first_message
context.target_message_id = thread.original_message_id context[:target_message_id] = thread.original_message_id
elsif contract.fetch_from_last_read || !contract.target_message_id elsif contract.fetch_from_last_read || !contract.target_message_id
context.target_message_id = membership&.last_read_message_id context[:target_message_id] = membership&.last_read_message_id
elsif contract.target_message_id elsif contract.target_message_id
context.target_message_id = contract.target_message_id context[:target_message_id] = contract.target_message_id
end end
end end
@ -99,8 +99,8 @@ module Chat
contract.fetch_from_first_message || contract.fetch_from_last_message, contract.fetch_from_first_message || contract.fetch_from_last_message,
) )
context.can_load_more_past = messages_data[:can_load_more_past] context[:can_load_more_past] = messages_data[:can_load_more_past]
context.can_load_more_future = messages_data[:can_load_more_future] context[:can_load_more_future] = messages_data[:can_load_more_future]
[ [
messages_data[:messages], messages_data[:messages],

View File

@ -51,11 +51,11 @@ module Chat
private private
def set_limit(contract:) def set_limit(contract:)
context.limit = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT) context[:limit] = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
end end
def set_offset(contract:) def set_offset(contract:)
context.offset = [contract.offset || 0, 0].max context[:offset] = [contract.offset || 0, 0].max
end end
def fetch_channel(contract:) def fetch_channel(contract:)
@ -118,33 +118,30 @@ module Chat
end end
def fetch_tracking(guardian:, threads:) def fetch_tracking(guardian:, threads:)
context.tracking = context[:tracking] = ::Chat::TrackingStateReportQuery.call(
::Chat::TrackingStateReportQuery.call( guardian: guardian,
guardian: guardian, thread_ids: threads.map(&:id),
thread_ids: threads.map(&:id), include_threads: true,
include_threads: true, ).thread_tracking
).thread_tracking
end end
def fetch_memberships(guardian:, threads:) def fetch_memberships(guardian:, threads:)
context.memberships = context[:memberships] = ::Chat::UserChatThreadMembership.where(
::Chat::UserChatThreadMembership.where( thread_id: threads.map(&:id),
thread_id: threads.map(&:id), user_id: guardian.user.id,
user_id: guardian.user.id, )
)
end end
def fetch_participants(threads:) def fetch_participants(threads:)
context.participants = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id)) context[:participants] = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
end end
def build_load_more_url(contract:) def build_load_more_url(contract:)
load_more_params = { offset: context.offset + context.limit }.to_query load_more_params = { offset: context.offset + context.limit }.to_query
context.load_more_url = context[:load_more_url] = ::URI::HTTP.build(
::URI::HTTP.build( path: "/chat/api/channels/#{contract.channel_id}/threads",
path: "/chat/api/channels/#{contract.channel_id}/threads", query: load_more_params,
query: load_more_params, ).request_uri
).request_uri
end end
end end
end end

View File

@ -35,11 +35,11 @@ module Chat
private private
def set_limit(contract:) def set_limit(contract:)
context.limit = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT) context[:limit] = (contract.limit || THREADS_LIMIT).to_i.clamp(1, THREADS_LIMIT)
end end
def set_offset(contract:) def set_offset(contract:)
context.offset = [contract.offset || 0, 0].max context[:offset] = [contract.offset || 0, 0].max
end end
def fetch_threads(guardian:) def fetch_threads(guardian:)
@ -112,31 +112,31 @@ module Chat
end end
def fetch_tracking(guardian:, threads:) def fetch_tracking(guardian:, threads:)
context.tracking = context[:tracking] = ::Chat::TrackingStateReportQuery.call(
::Chat::TrackingStateReportQuery.call( guardian: guardian,
guardian: guardian, thread_ids: threads.map(&:id),
thread_ids: threads.map(&:id), include_threads: true,
include_threads: true, ).thread_tracking
).thread_tracking
end end
def fetch_memberships(guardian:, threads:) def fetch_memberships(guardian:, threads:)
context.memberships = context[:memberships] = ::Chat::UserChatThreadMembership.where(
::Chat::UserChatThreadMembership.where( thread_id: threads.map(&:id),
thread_id: threads.map(&:id), user_id: guardian.user.id,
user_id: guardian.user.id, )
)
end end
def fetch_participants(threads:) def fetch_participants(threads:)
context.participants = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id)) context[:participants] = ::Chat::ThreadParticipantQuery.call(thread_ids: threads.map(&:id))
end end
def build_load_more_url(contract:) def build_load_more_url(contract:)
load_more_params = { limit: context.limit, offset: context.offset + context.limit }.to_query load_more_params = { limit: context.limit, offset: context.offset + context.limit }.to_query
context.load_more_url = context[:load_more_url] = ::URI::HTTP.build(
::URI::HTTP.build(path: "/chat/api/me/threads", query: load_more_params).request_uri path: "/chat/api/me/threads",
query: load_more_params,
).request_uri
end end
end end
end end

View File

@ -54,7 +54,7 @@ module Chat
notification_level: Chat::NotificationLevels.all[:normal], notification_level: Chat::NotificationLevels.all[:normal],
) if !membership ) if !membership
membership.update!(thread_title_prompt_seen: true) membership.update!(thread_title_prompt_seen: true)
context.membership = membership context[:membership] = membership
end end
end end
end end

View File

@ -34,7 +34,7 @@ module Chat
private private
def clean_term(contract:) def clean_term(contract:)
context.term = contract.term&.downcase&.strip&.gsub(/^[@#]+/, "") context[:term] = contract.term&.downcase&.strip&.gsub(/^[@#]+/, "")
end end
def fetch_memberships(guardian:) def fetch_memberships(guardian:)

View File

@ -32,7 +32,7 @@ module Chat
end end
def unfollow(channel:, guardian:) def unfollow(channel:, guardian:)
context.membership = channel.remove(guardian.user) context[:membership] = channel.remove(guardian.user)
end end
end end
end end

View File

@ -66,7 +66,7 @@ module Chat
def update_channel(channel:, contract:) def update_channel(channel:, contract:)
channel.assign_attributes(contract.attributes) channel.assign_attributes(contract.attributes)
context.threading_enabled_changed = channel.threading_enabled_changed? context[:threading_enabled_changed] = channel.threading_enabled_changed?
channel.save! channel.save!
end end

View File

@ -120,12 +120,11 @@ module Chat
prev_message = message.message_before_last_save || message.message_was prev_message = message.message_before_last_save || message.message_was
return if !should_create_revision(message, prev_message, guardian) return if !should_create_revision(message, prev_message, guardian)
context.revision = context[:revision] = message.revisions.create!(
message.revisions.create!( old_message: prev_message,
old_message: prev_message, new_message: message.message,
new_message: message.message, user_id: guardian.user.id,
user_id: guardian.user.id, )
)
end end
def should_create_revision(new_message, prev_message, guardian) def should_create_revision(new_message, prev_message, guardian)
@ -151,7 +150,7 @@ module Chat
end end
def publish(message:, guardian:, contract:) def publish(message:, guardian:, contract:)
edit_timestamp = context.revision&.created_at&.iso8601(6) || Time.zone.now.iso8601(6) edit_timestamp = context[:revision]&.created_at&.iso8601(6) || Time.zone.now.iso8601(6)
::Chat::Publisher.publish_edit!(message.chat_channel, message) ::Chat::Publisher.publish_edit!(message.chat_channel, message)

View File

@ -60,7 +60,7 @@ module Chat
membership.update!(last_read_message_id: thread.last_message_id) membership.update!(last_read_message_id: thread.last_message_id)
end end
membership.update!(notification_level: contract.notification_level) membership.update!(notification_level: contract.notification_level)
context.membership = membership context[:membership] = membership
end end
end end
end end

View File

@ -1,6 +1,6 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe StepsInspector do RSpec.describe Service::StepsInspector do
class DummyService class DummyService
include Service::Base include Service::Base
@ -239,7 +239,7 @@ RSpec.describe StepsInspector do
end end
context "when a reason is provided" do context "when a reason is provided" do
before { result["result.policy.policy"].reason = "failed" } before { result["result.policy.policy"][:reason] = "failed" }
it "returns the reason" do it "returns the reason" do
expect(error).to eq "failed" expect(error).to eq "failed"

View File

@ -26,7 +26,7 @@ module ServiceMatchers
private private
def error_message_with_inspection(message) def error_message_with_inspection(message)
inspector = StepsInspector.new(result) inspector = Service::StepsInspector.new(result)
"#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}" "#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}"
end end
end end
@ -89,7 +89,7 @@ module ServiceMatchers
end end
def error_message_with_inspection(message) def error_message_with_inspection(message)
inspector = StepsInspector.new(result) inspector = Service::StepsInspector.new(result)
"#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}" "#{message}\n\n#{inspector.inspect}\n\n#{inspector.error}"
end end
@ -158,7 +158,7 @@ module ServiceMatchers
end end
def inspect_steps(result) def inspect_steps(result)
inspector = StepsInspector.new(result) inspector = Service::StepsInspector.new(result)
puts "Steps:" puts "Steps:"
puts inspector.inspect puts inspector.inspect
puts "\nFirst error:" puts "\nFirst error:"