Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions apps/documentation/server/a2a-server.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { cors } from '@elysiajs/cors'
import {
CORE_PORTS,
getLocalhostHost,
Expand Down Expand Up @@ -144,6 +143,25 @@ const AGENT_CARD = {
],
} as const

function resolveOrigin(origin?: string | null): string | null {
if (!origin) return null
return ALLOWED_ORIGINS.includes(origin) ? origin : null
}

function applyCorsHeaders(
set: { headers: Record<string, string | string[] | number> },
request: Request,
) {
const allowedOrigin = resolveOrigin(request.headers.get('origin'))
if (!allowedOrigin) return
set.headers['Access-Control-Allow-Origin'] = allowedOrigin
set.headers['Access-Control-Allow-Credentials'] = 'true'
set.headers['Access-Control-Allow-Headers'] = 'Content-Type'
set.headers['Access-Control-Allow-Methods'] = 'GET,POST,OPTIONS'
set.headers['Access-Control-Max-Age'] = '86400'
set.headers.Vary = 'Origin'
}

/** Validate documentation page path (no traversal allowed) */
function validateDocPath(pagePath: string): string {
// Normalize and check for path traversal
Expand Down Expand Up @@ -204,16 +222,14 @@ async function executeSkill(
}

export const app = new Elysia()
.use(
cors({
origin: (request) => {
const origin = request.headers.get('origin')
if (!origin) return true
return ALLOWED_ORIGINS.includes(origin)
},
credentials: true,
}),
)
.onRequest(({ request, set }) => {
applyCorsHeaders(set, request)
return
})
.options('*', ({ request, set }) => {
applyCorsHeaders(set, request)
return new Response(null, { status: 204 })
})
.derive(({ request, server }) => {
const forwarded = request.headers.get('x-forwarded-for')
const clientIp =
Expand Down
2 changes: 1 addition & 1 deletion apps/documentation/tests/unit/a2a.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ describe('A2A Server Structure', () => {

test('has proper CORS configuration', async () => {
const serverCode = await Bun.file(SERVER_PATH).text()
expect(serverCode).toContain('cors(')
expect(serverCode).toContain('ALLOWED_ORIGINS')
expect(serverCode).toContain('Access-Control-Allow-Origin')
})
})

Expand Down
28 changes: 15 additions & 13 deletions apps/dws/api/rlaif/trainers/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,34 @@ def setup(self, reference_model_cid: str | None = None):

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

self.model = AutoModelForCausalLM.from_pretrained(
model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
)
))
device = torch.device(self.device)
self.model = cast(PreTrainedModel, torch.nn.Module.to(self.model, device))
self.model.gradient_checkpointing_enable()
self.model.train()
model = cast(PreTrainedModel, torch.nn.Module.to(model, device))
model.gradient_checkpointing_enable()
model.train()
self.model = model

# Load or clone reference model
if reference_model_cid:
ref_path = self._download_model(reference_model_cid)
self.ref_model = AutoModelForCausalLM.from_pretrained(
ref_model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained(
ref_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
))
else:
# Clone current model as reference
self.ref_model = AutoModelForCausalLM.from_pretrained(
ref_model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
)
))

self.ref_model = cast(PreTrainedModel, torch.nn.Module.to(self.ref_model, device))
self.ref_model.eval()
for param in self.ref_model.parameters():
ref_model = cast(PreTrainedModel, torch.nn.Module.to(ref_model, device))
ref_model.eval()
for param in ref_model.parameters():
param.requires_grad = False
self.ref_model = ref_model

self.optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)
self.optimizer = AdamW(model.parameters(), lr=self.learning_rate)

logger.info(f"Model loaded on {self.device}")

Expand Down
1 change: 0 additions & 1 deletion apps/oauth3/api/routes/auth-init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,6 @@ export function createAuthInitRouter(config: AuthConfig) {
return { error: 'invalid_origin' }
}

const appId = body.appId ?? 'jeju-default'
const rpId = origin.hostname

const existingCredentials = await passkeyState.listCredentialsByRpId(
Expand Down
11 changes: 7 additions & 4 deletions apps/oauth3/api/services/kms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ export async function generateSecureToken(
}

const config = getConfig()
const header = { alg: 'ES256K', typ: 'JWT', kid: config.jwtSigningKeyId }
const headerB64 = base64urlEncode(JSON.stringify(header))
const headerB64 = base64urlEncode(
JSON.stringify({ alg: 'ES256K', typ: 'JWT', kid: config.jwtSigningKeyId }),
)
const payloadB64 = base64urlEncode(JSON.stringify(claims))
const signingInput = `${headerB64}.${payloadB64}`
const messageHash = keccak256(toBytes(signingInput))
Expand Down Expand Up @@ -132,10 +133,12 @@ export async function verifySecureToken(token: string): Promise<string | null> {

const [headerB64, payloadB64, signatureB64] = parts

let header: { alg?: string; kid?: string }
let claims: { sub?: string; iss?: string; exp?: number }
try {
header = JSON.parse(base64urlDecode(headerB64))
const header = JSON.parse(base64urlDecode(headerB64))
if (header?.alg && header.alg !== 'ES256K') {
return null
}
claims = JSON.parse(base64urlDecode(payloadB64))
} catch {
return null
Expand Down
6 changes: 5 additions & 1 deletion apps/oauth3/api/services/sealed-oauth.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import type { AuthProvider, SealedOAuthProvider } from '../../lib/types'
import type {
AuthProvider,
SealedOAuthProvider,
SealedSecret,
} from '../../lib/types'
import { z } from 'zod'
import { sealSecret, unsealSecret } from './kms'

Expand Down
10 changes: 6 additions & 4 deletions apps/wallet/web/components/auth/LinkedAccounts.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ type ProviderInfo = {
description: string
}

const PROVIDER_INFO: Record<AuthProvider, ProviderInfo> = {
type SupportedAuthProvider = Exclude<AuthProvider, 'email' | 'phone'>

const PROVIDER_INFO: Record<SupportedAuthProvider, ProviderInfo> = {
[AuthProvider.WALLET]: {
name: 'Wallet',
icon: Wallet,
Expand Down Expand Up @@ -85,13 +87,13 @@ const PROVIDER_INFO: Record<AuthProvider, ProviderInfo> = {
}

type LinkedProvider = {
provider: AuthProvider
provider: SupportedAuthProvider
providerId: string
handle?: string
linkedAt: number
}

const SUPPORTED_PROVIDERS: AuthProvider[] = [
const SUPPORTED_PROVIDERS: SupportedAuthProvider[] = [
AuthProvider.WALLET,
AuthProvider.GOOGLE,
AuthProvider.APPLE,
Expand All @@ -102,7 +104,7 @@ const SUPPORTED_PROVIDERS: AuthProvider[] = [
AuthProvider.PASSKEY,
]

function toAuthProvider(type: string): AuthProvider | null {
function toAuthProvider(type: string): SupportedAuthProvider | null {
switch (type) {
case 'wallet':
return AuthProvider.WALLET
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"dev": "bun run jeju dev",
"build": "bun run jeju build",
"test": "bun run jeju test",
"typecheck": "turbo run typecheck --filter='!./vendor/*' && tsc --noEmit && pyright packages/training/python apps/dws/api && bun run \"typecheck:rust\"",
"typecheck": "turbo run typecheck && tsc --noEmit && pyright packages/training/python apps/dws/api && bun run \"typecheck:rust\"",
"typecheck:rust": "bash -c '. $HOME/.cargo/env 2>/dev/null; mkdir -p apps/node/app/dist apps/vpn/app/dist apps/wallet/app/dist; cargo check --manifest-path apps/node/app/src-tauri/Cargo.toml && cargo check --manifest-path apps/vpn/app/src-tauri/Cargo.toml && cargo check --manifest-path apps/wallet/app/src-tauri/Cargo.toml'",
"lint": "biome check --write --unsafe packages apps && ruff check packages apps --exclude packages/workerd --exclude packages/contracts/lib --fix && ruff format packages apps --exclude packages/workerd --exclude packages/contracts/lib && bun run \"lint:rust\"",
"lint:rust": "bash -c '. $HOME/.cargo/env 2>/dev/null; cargo fmt --manifest-path apps/node/app/src-tauri/Cargo.toml && cargo fmt --manifest-path apps/vpn/app/src-tauri/Cargo.toml && cargo fmt --manifest-path apps/wallet/app/src-tauri/Cargo.toml'",
Expand Down
78 changes: 52 additions & 26 deletions packages/auth/src/credentials/verifiable-credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,28 @@ export class VerifiableCredentialIssuer {
}

private getCredentialTypeForProvider(provider: AuthProvider): string {
const typeMap: Record<AuthProvider, string> = {
wallet: 'WalletOwnershipCredential',
farcaster: 'FarcasterAccountCredential',
google: 'GoogleAccountCredential',
apple: 'AppleAccountCredential',
twitter: 'TwitterAccountCredential',
github: 'GitHubAccountCredential',
discord: 'DiscordAccountCredential',
email: 'EmailAccountCredential',
phone: 'PhoneAccountCredential',
}

return typeMap[provider] ?? 'OAuth3IdentityCredential'
switch (provider) {
case 'wallet':
return 'WalletOwnershipCredential'
case 'passkey':
return 'PasskeyAccountCredential'
case 'farcaster':
return 'FarcasterAccountCredential'
case 'google':
return 'GoogleAccountCredential'
case 'apple':
return 'AppleAccountCredential'
case 'twitter':
return 'TwitterAccountCredential'
case 'github':
return 'GitHubAccountCredential'
case 'discord':
return 'DiscordAccountCredential'
case 'email':
return 'EmailAccountCredential'
case 'phone':
return 'PhoneAccountCredential'
}
}

private createJWS(hash: Hex, challenge: string, domain?: string): string {
Expand Down Expand Up @@ -594,6 +603,35 @@ export class VerifiableCredentialVerifier {
}
}

export function getOnChainProviderId(provider: AuthProvider): number {
switch (provider) {
case 'wallet':
return 0
case 'farcaster':
return 1
case 'google':
return 2
case 'apple':
return 3
case 'twitter':
return 4
case 'github':
return 5
case 'discord':
return 6
case 'email':
return 7
case 'phone':
return 8
case 'passkey':
throw new Error('Passkey credentials are not yet supported on-chain')
default: {
const _exhaustive: never = provider
return _exhaustive
}
}
}

export function createCredentialHash(credential: VerifiableCredential): Hex {
const essential = {
type: credential.type,
Expand All @@ -613,20 +651,8 @@ export function credentialToOnChainAttestation(
issuedAt: number
expiresAt: number
} {
const providerMap: Record<AuthProvider, number> = {
wallet: 0,
farcaster: 1,
google: 2,
apple: 3,
twitter: 4,
github: 5,
discord: 6,
email: 7,
phone: 8,
}

return {
provider: providerMap[credential.credentialSubject.provider],
provider: getOnChainProviderId(credential.credentialSubject.provider),
providerId: keccak256(toBytes(credential.credentialSubject.providerId)),
credentialHash: createCredentialHash(credential),
issuedAt: Math.floor(new Date(credential.issuanceDate).getTime() / 1000),
Expand Down
51 changes: 43 additions & 8 deletions packages/auth/src/sdk/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function generateUUID(): string {
return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}`
}

function arrayBufferToBase64url(buffer: ArrayBuffer): string {
function arrayBufferToBase64url(buffer: ArrayBuffer | SharedArrayBuffer): string {
const bytes = new Uint8Array(buffer)
let binary = ''
for (const byte of bytes) {
Expand All @@ -56,6 +56,28 @@ function isPasskeyRequestOptions(
return 'challenge' in options && !('user' in options)
}

function isAuthenticatorAssertionResponse(
response: AuthenticatorResponse,
): response is AuthenticatorAssertionResponse {
return 'authenticatorData' in response && 'signature' in response
}

function safeArrayBufferToBase64url(
value:
| ArrayBuffer
| SharedArrayBuffer
| ArrayBufferView
| null
| undefined,
): string | undefined {
if (!value) return undefined
const buffer =
value instanceof ArrayBuffer || value instanceof SharedArrayBuffer
? value
: value.buffer.slice(value.byteOffset, value.byteOffset + value.byteLength)
return arrayBufferToBase64url(buffer)
}

// OAuth callback data schema
const OAuthCallbackSchema = z.object({
code: z.string().optional(),
Expand Down Expand Up @@ -622,21 +644,34 @@ export class OAuth3Client {
if (!('attestationObject' in response)) {
throw new Error('Invalid passkey registration response')
}
const registrationResponse = response as AuthenticatorAttestationResponse & {
attestationObject: ArrayBuffer | SharedArrayBuffer
clientDataJSON: ArrayBuffer
}
responsePayload = {
clientDataJSON,
attestationObject: arrayBufferToBase64url(response.attestationObject),
attestationObject: safeArrayBufferToBase64url(
registrationResponse.attestationObject,
),
}
} else {
if (!('authenticatorData' in response) || !('signature' in response)) {
if (!isAuthenticatorAssertionResponse(response)) {
throw new Error('Invalid passkey authentication response')
}
const assertionResponse = response as AuthenticatorAssertionResponse & {
authenticatorData: ArrayBuffer | SharedArrayBuffer
signature: ArrayBuffer | SharedArrayBuffer
userHandle?: ArrayBuffer | SharedArrayBuffer | null
}
responsePayload = {
clientDataJSON,
authenticatorData: arrayBufferToBase64url(response.authenticatorData),
signature: arrayBufferToBase64url(response.signature),
userHandle: response.userHandle
? arrayBufferToBase64url(response.userHandle)
: undefined,
authenticatorData: safeArrayBufferToBase64url(
assertionResponse.authenticatorData,
),
signature: safeArrayBufferToBase64url(assertionResponse.signature),
userHandle: safeArrayBufferToBase64url(
assertionResponse.userHandle,
),
}
}

Expand Down
4 changes: 3 additions & 1 deletion packages/auth/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
* threshold MPC signing, and W3C Verifiable Credentials.
*/

import type { JsonRecord, TEEAttestation } from '@jejunetwork/types'
import type { TEEAttestation } from '@jejunetwork/types'
import type { Address, Hex } from 'viem'

export type { JsonRecord } from '@jejunetwork/types'

export const AuthProvider = {
WALLET: 'wallet',
PASSKEY: 'passkey',
Expand Down
Loading
Loading