diff --git a/endpoints/publicProxy/config.go b/endpoints/publicProxy/config.go index cac3da354..233c1a4ea 100644 --- a/endpoints/publicProxy/config.go +++ b/endpoints/publicProxy/config.go @@ -16,11 +16,16 @@ type Config struct { Identity string Address string HostMatch string - Interstitial bool + Interstitial *InterstitialConfig Oauth *OauthConfig Tls *endpoints.TlsConfig } +type InterstitialConfig struct { + Enabled bool + UserAgentPrefixes []string +} + type OauthConfig struct { BindAddress string RedirectUrl string @@ -46,9 +51,8 @@ type OauthProviderConfig struct { func DefaultConfig() *Config { return &Config{ - Identity: "public", - Address: "0.0.0.0:8080", - Interstitial: false, + Identity: "public", + Address: "0.0.0.0:8080", } } diff --git a/endpoints/publicProxy/http.go b/endpoints/publicProxy/http.go index 5e3ba4e81..69de6330e 100644 --- a/endpoints/publicProxy/http.go +++ b/endpoints/publicProxy/http.go @@ -158,15 +158,31 @@ func shareHandler(handler http.Handler, pcfg *Config, key []byte, ctx ziti.Conte if shrToken != "" { if svc, found := endpoints.GetRefreshedService(shrToken, ctx); found { if cfg, found := svc.Config[sdk.ZrokProxyConfig]; found { - if pcfg.Interstitial { - if v, istlFound := cfg["interstitial"]; istlFound { - if istlEnabled, ok := v.(bool); ok && istlEnabled { - skip := r.Header.Get("skip_zrok_interstitial") - _, zrokOkErr := r.Cookie("zrok_interstitial") - if skip == "" && zrokOkErr != nil { - logrus.Debugf("forcing interstitial for '%v'", r.URL) - interstitialUi.WriteInterstitialAnnounce(w) - return + if pcfg.Interstitial != nil && pcfg.Interstitial.Enabled { + sendInterstitial := true + if len(pcfg.Interstitial.UserAgentPrefixes) > 0 { + ua := r.Header.Get("User-Agent") + matched := false + for _, prefix := range pcfg.Interstitial.UserAgentPrefixes { + if strings.HasPrefix(ua, prefix) { + matched = true + break + } + } + if !matched { + sendInterstitial = false + } + } + if sendInterstitial { + if v, istlFound := cfg["interstitial"]; istlFound { + if istlEnabled, ok := v.(bool); ok && istlEnabled { + skip := r.Header.Get("skip_zrok_interstitial") + _, zrokOkErr := r.Cookie("zrok_interstitial") + if skip == "" && zrokOkErr != nil { + logrus.Debugf("forcing interstitial for '%v'", r.URL) + interstitialUi.WriteInterstitialAnnounce(w) + return + } } } }