lib/anthropic/vertex/client.rb (95 lines of code) (raw):

# frozen_string_literal: true

module Anthropic
  module Vertex
    # Client for Anthropic models served through the GCP Vertex API.
    #
    # It rewrites the standard Messages request paths into the Vertex
    # `rawPredict` / `streamRawPredict` publisher-model endpoints and adds
    # authorization headers obtained from GCP Application Default Credentials
    # whenever the outgoing request carries no `Authorization` header.
    class Client < Anthropic::Client
      # Value written to the request body's `anthropic_version` field when the caller did not set one.
      DEFAULT_VERSION = "vertex-2023-10-16"

      # @return [String]
      attr_reader :region

      # @return [String]
      attr_reader :project_id

      # @return [Anthropic::Resources::Messages]
      attr_reader :messages

      # @return [Anthropic::Resources::Beta]
      attr_reader :beta

      # Creates and returns a new client for interacting with the GCP Vertex API for Anthropic models.
      #
      # GCP credentials are resolved according to Application Default Credentials, described at
      # https://cloud.google.com/docs/authentication/provide-credentials-adc
      #
      # @param region [String, nil] Enforce the GCP Region to use. If unset, the region may be set
      #   with the CLOUD_ML_REGION environment variable.
      #
      # @param project_id [String, nil] Enforce the GCP Project ID to use. If unset, the project_id
      #   may be set with the ANTHROPIC_VERTEX_PROJECT_ID environment variable.
      #
      # @param base_url [String, nil] Override the default base URL for the API, e.g.,
      #   `"https://api.example.com/v2/"`. When nil, the ANTHROPIC_VERTEX_BASE_URL environment
      #   variable is used if set, otherwise the regional `aiplatform.googleapis.com` endpoint.
      #
      # @param max_retries [Integer] The maximum number of times to retry a request if it fails
      #
      # @param timeout [Float] The number of seconds to wait for a response before timing out
      #
      # @param initial_retry_delay [Float] The number of seconds to wait before retrying a request
      #
      # @param max_retry_delay [Float] The maximum number of seconds to wait before retrying a request
      #
      # @raise [RuntimeError] if the `googleauth` gem cannot be loaded
      # @raise [ArgumentError] if `region` or `project_id` is nil or empty
      def initialize(
        region: ENV["CLOUD_ML_REGION"],
        project_id: ENV["ANTHROPIC_VERTEX_PROJECT_ID"],
        base_url: nil,
        max_retries: DEFAULT_MAX_RETRIES,
        timeout: DEFAULT_TIMEOUT_IN_SECONDS,
        initial_retry_delay: DEFAULT_INITIAL_RETRY_DELAY,
        max_retry_delay: DEFAULT_MAX_RETRY_DELAY
      )
        # `googleauth` is loaded lazily so that it is only needed by users of the Vertex client.
        begin
          require("googleauth")
        rescue LoadError
          raise <<~MSG

            In order to access Anthropic models on Vertex you must require the `googleauth` gem.
            You can install it by adding the following to your Gemfile:

                gem "googleauth"

            and then running `bundle install`.

            Alternatively, if you are not using Bundler, simply run:

                gem install googleauth
          MSG
        end

        # `to_s.empty?` rejects both nil (environment variable unset) and "".
        if region.to_s.empty?
          raise ArgumentError,
                "No region was given. The client should be instantiated with the `region` argument or the `CLOUD_ML_REGION` environment variable should be set."
        end
        @region = region

        if project_id.to_s.empty?
          raise ArgumentError,
                "No project_id was given and it could not be resolved from credentials. The client should be instantiated with the `project_id` argument or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set."
        end
        @project_id = project_id

        # Precedence: explicit argument, then environment override, then the regional default.
        base_url ||= ENV.fetch(
          "ANTHROPIC_VERTEX_BASE_URL",
          "https://#{@region}-aiplatform.googleapis.com/v1"
        )

        super(
          base_url: base_url,
          timeout: timeout,
          max_retries: max_retries,
          initial_retry_delay: initial_retry_delay,
          max_retry_delay: max_retry_delay
        )

        @messages = Anthropic::Resources::Messages.new(client: self)
        @beta = Anthropic::Resources::Beta.new(client: self)
      end

      # @private
      #
      # Reshapes the request for Vertex, delegates to the parent implementation, and then adds
      # Application Default Credentials headers when no `Authorization` header is present.
      #
      # @param req [Hash{Symbol=>Object}] .
      #
      #   @option req [Symbol] :method
      #
      #   @option req [String, Array<String>] :path
      #
      #   @option req [Hash{String=>Array<String>, String, nil}, nil] :query
      #
      #   @option req [Hash{String=>String, nil}, nil] :headers
      #
      #   @option req [Object, nil] :body
      #
      #   @option req [Symbol, nil] :unwrap
      #
      #   @option req [Class, nil] :page
      #
      #   @option req [Anthropic::Converter, Class, nil] :model
      #
      # @param opts [Hash{Symbol=>Object}] .
      #
      #   @option opts [String, nil] :idempotency_key
      #
      #   @option opts [Hash{String=>Array<String>, String, nil}, nil] :extra_query
      #
      #   @option opts [Hash{String=>String, nil}, nil] :extra_headers
      #
      #   @option opts [Hash{Symbol=>Object}, nil] :extra_body
      #
      #   @option opts [Integer, nil] :max_retries
      #
      #   @option opts [Float, nil] :timeout
      #
      # @return [Hash{Symbol=>Object}]
      private def build_request(req, opts)
        # Must run before `super` so the parent builds the request from the Vertex-shaped components.
        fit_req_to_vertex_specs!(req)

        # Bare `super` forwards `req` and `opts` unchanged.
        request_input = super

        unless request_input[:headers]["Authorization"]
          authorization = Google::Auth.get_application_default(["https://www.googleapis.com/auth/cloud-platform"])
          request_input[:headers] = authorization.apply(request_input[:headers])
        end

        request_input
      end

      # @private
      #
      # Overrides request components for Vertex-specific request-shape requirements.
      # The given hash (and its `:body`, when that is a Hash) is mutated in place.
      #
      # @param request_components [Hash{Symbol=>Object}] .
      #
      #   @option request_components [Symbol] :method
      #
      #   @option request_components [String, Array<String>] :path
      #
      #   @option request_components [Hash{String=>Array<String>, String, nil}, nil] :query
      #
      #   @option request_components [Hash{String=>String, nil}, nil] :headers
      #
      #   @option request_components [Object, nil] :body
      #
      #   @option request_components [Symbol, nil] :unwrap
      #
      #   @option request_components [Class, nil] :page
      #
      #   @option request_components [Anthropic::Converter, Class, nil] :model
      #
      # @raise [Anthropic::Error] if a POST to the messages endpoint has a non-Hash body, or if
      #   the request targets the Batch API
      #
      # @return [Hash{Symbol=>Object}] the same `request_components` hash
      def fit_req_to_vertex_specs!(request_components)
        if (body = request_components[:body]).is_a?(Hash)
          body[:anthropic_version] ||= DEFAULT_VERSION

          # Beta flags are moved out of the body and into the `anthropic-beta` header.
          if (anthropic_beta = body.delete(:"anthropic-beta"))
            request_components[:headers] ||= {}
            # `Array()` accepts a single flag given as a String as well as a list of flags.
            request_components[:headers]["anthropic-beta"] = Array(anthropic_beta).join(",")
          end
        end

        is_post = request_components[:method] == :post

        if is_post && %w[v1/messages v1/messages?beta=true].include?(request_components[:path])
          unless body.is_a?(Hash)
            raise Anthropic::Error, "Expected json data to be a hash for post /v1/messages"
          end

          # The model is removed from the body because it is carried in the endpoint path instead.
          model = body.delete(:model)
          specifier = body[:stream] ? "streamRawPredict" : "rawPredict"

          request_components[:path] =
            "projects/#{@project_id}/locations/#{region}/publishers/anthropic/models/#{model}:#{specifier}"
        end

        if is_post &&
           %w[v1/messages/count_tokens v1/messages/count_tokens?beta=true].include?(request_components[:path])
          request_components[:path] =
            "projects/#{@project_id}/locations/#{region}/publishers/anthropic/models/count-tokens:rawPredict"
        end

        # `:path` is documented as `String, Array<String>`. For the Array form the first element
        # is assumed to be the path template -- TODO confirm against the resource definitions.
        path = request_components[:path]
        path_template = path.is_a?(Array) ? path.first : path

        # Prefix match without a trailing slash so that the collection endpoint itself
        # (`v1/messages/batches`, with or without a query string) is rejected too.
        if path_template.to_s.start_with?("v1/messages/batches")
          raise Anthropic::Error, "The Batch API is not supported in the Vertex client yet"
        end

        request_components
      end
    end
  end
end